mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
feat(storage): primary key for all tables and general qol refactoring (#2431)
This is a massive overhaul to the SQL Storage for Authelia. It facilitates a whole heap of utility commands to help manage the database, primary keys, ensures all database requests use a context for cancellations, and paves the way for a few other PR's which improve the database. Fixes #1337
This commit is contained in:
parent
884dc99083
commit
3695aa8140
|
@ -1,14 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/commands"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
)
|
||||
|
||||
func main() {
|
||||
logger := logging.Logger()
|
||||
|
||||
if err := commands.NewRootCmd().Execute(); err != nil {
|
||||
logger.Fatal(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -264,4 +264,5 @@ In versions <= `4.24.0` not including the `username_attribute` placeholder will
|
|||
and will result in session resets when the refresh interval has expired, default of 5 minutes.
|
||||
|
||||
[LDAP GeneralizedTime]: https://ldapwiki.com/wiki/GeneralizedTime
|
||||
[username attribute]: #username_attribute
|
||||
[TechNet wiki]: https://social.technet.microsoft.com/wiki/contents/articles/5392.active-directory-ldap-syntax-filters.aspx
|
||||
|
|
24
docs/configuration/storage/migrations.md
Normal file
24
docs/configuration/storage/migrations.md
Normal file
|
@ -0,0 +1,24 @@
|
|||
---
|
||||
layout: default
|
||||
title: Migrations
|
||||
parent: Storage Backends
|
||||
grand_parent: Configuration
|
||||
nav_order: 5
|
||||
---
|
||||
|
||||
Storage migrations are important for keeping your database compatible with Authelia. Authelia will automatically upgrade
|
||||
your schema on startup. However, if you wish to use an older version of Authelia you may be required to manually
|
||||
downgrade your schema with a version of Authelia that supports your current schema.
|
||||
|
||||
## Schema Version to Authelia Version map
|
||||
|
||||
This table contains a list of schema versions and the corresponding release of Authelia that shipped with that version.
|
||||
This means all Authelia versions between two schema versions use the first schema version.
|
||||
|
||||
For example for version pre1, it is used for all versions between it and the version 1 schema, so 4.0.0 to 4.32.2. In
|
||||
this instance if you wanted to downgrade to pre1 you would need to use an Authelia binary with version 4.33.0 or higher.
|
||||
|
||||
|Schema Version|Authelia Version|Notes |
|
||||
|:------------:|:--------------:|:----------------------------------------------------------:|
|
||||
|pre1 |4.0.0 |Downgrading to this version requires you use the --pre1 flag|
|
||||
|1 |4.33.0 | |
|
|
@ -99,3 +99,36 @@ This section has the required status of the value and must be one of `yes`, `no`
|
|||
depends on other configuration options. If it's situational the situational usage should be documented. This is
|
||||
immediately followed by the styles `.label`, `.label-config`, and a traffic lights color label, i.e. if yes `.label-red`,
|
||||
if no `.label-green`, or if situational `.label-yellow`.
|
||||
|
||||
### Storage
|
||||
This section outlines some rules for storage contributions. Including but not limited to migrations, schema rules, etc.
|
||||
|
||||
#### Migrations
|
||||
All migrations must have an up and down migration, preferably idempotent.
|
||||
|
||||
All migrations must be named in the following format:
|
||||
```text
|
||||
V<version>.<name>.<engine>.<direction>.sql
|
||||
```
|
||||
|
||||
##### version
|
||||
A 4 digit version number, should be in sequential order.
|
||||
|
||||
##### name
|
||||
A name containing alphanumeric characters, underscores (treated as spaces), hyphens, and no spaces.
|
||||
|
||||
##### engine
|
||||
The target engine for the migration, options are all, mysql, postgres, and sqlite.
|
||||
|
||||
#### Primary Key
|
||||
All tables must have a primary key. This primary key must be an integer with auto increment enabled, or in the case of
|
||||
PostgreSQL a serial type.
|
||||
|
||||
#### Table/Column Names
|
||||
Table and Column names must be in snake case format. This means they must have only lowercase letters, and have words
|
||||
seperated by underscores. The reasoning for this is that some database engines ignore case by default and this makes it
|
||||
easy to be consistent with the casing.
|
||||
|
||||
#### Context
|
||||
All database methods should include the context attribute so that database requests that are no longer needed are
|
||||
terminated appropriately.
|
||||
|
|
84
go.mod
84
go.mod
|
@ -1,15 +1,12 @@
|
|||
module github.com/authelia/authelia/v4
|
||||
|
||||
go 1.16
|
||||
go 1.17
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0
|
||||
github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4
|
||||
github.com/Workiva/go-datastructures v1.0.53
|
||||
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef
|
||||
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
|
||||
github.com/deckarep/golang-set v1.7.1
|
||||
github.com/duosecurity/duo_api_golang v0.0.0-20211027140842-72da735c6f15
|
||||
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
|
||||
github.com/fasthttp/router v1.4.4
|
||||
github.com/fasthttp/session/v2 v2.4.4
|
||||
github.com/go-ldap/ldap/v3 v3.4.1
|
||||
|
@ -19,6 +16,7 @@ require (
|
|||
github.com/golang/mock v1.6.0
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/jackc/pgx/v4 v4.14.0
|
||||
github.com/jmoiron/sqlx v1.3.1
|
||||
github.com/knadh/koanf v1.3.2
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible
|
||||
github.com/mitchellh/mapstructure v1.4.2
|
||||
|
@ -30,16 +28,88 @@ require (
|
|||
github.com/simia-tech/crypt v0.5.0
|
||||
github.com/sirupsen/logrus v1.8.1
|
||||
github.com/spf13/cobra v1.2.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/tstranex/u2f v1.0.0
|
||||
github.com/valyala/fasthttp v1.31.0
|
||||
golang.org/x/sys v0.0.0-20210902050250-f475640dd07b // indirect
|
||||
golang.org/x/text v0.3.7
|
||||
gopkg.in/square/go-jose.v2 v2.6.0
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c // indirect
|
||||
github.com/andybalholm/brotli v1.0.2 // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgraph-io/ristretto v0.1.0 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.0 // indirect
|
||||
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
|
||||
github.com/fsnotify/fsnotify v1.4.9 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.1 // indirect
|
||||
github.com/go-redis/redis/v8 v8.11.4 // indirect
|
||||
github.com/gobuffalo/pop/v5 v5.3.3 // indirect
|
||||
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect
|
||||
github.com/golang/protobuf v1.5.2 // indirect
|
||||
github.com/gorilla/websocket v1.4.2 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.0.0 // indirect
|
||||
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
|
||||
github.com/jackc/pgconn v1.10.1 // indirect
|
||||
github.com/jackc/pgio v1.0.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgproto3/v2 v2.2.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
|
||||
github.com/jackc/pgtype v1.9.0 // indirect
|
||||
github.com/jandelgado/gcov2lcov v1.0.4 // indirect
|
||||
github.com/klauspost/compress v1.13.4 // indirect
|
||||
github.com/magiconair/properties v1.8.5 // indirect
|
||||
github.com/mattn/goveralls v0.0.6 // indirect
|
||||
github.com/mitchellh/copystructure v1.2.0 // indirect
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
|
||||
github.com/ory/go-acc v0.2.6 // indirect
|
||||
github.com/ory/go-convenience v0.1.0 // indirect
|
||||
github.com/ory/viper v1.7.5 // indirect
|
||||
github.com/ory/x v0.0.288 // indirect
|
||||
github.com/pborman/uuid v1.2.1 // indirect
|
||||
github.com/pelletier/go-toml v1.9.3 // indirect
|
||||
github.com/philhofer/fwd v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/savsgio/dictpool v0.0.0-20210921080634-84324d0689d7 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20210921075833-21a6215cb0e4 // indirect
|
||||
github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761 // indirect
|
||||
github.com/spf13/afero v1.6.0 // indirect
|
||||
github.com/spf13/cast v1.3.2-0.20200723214538-8d17101741c8 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
github.com/sqs/goreturns v0.0.0-20181028201513-538ac6014518 // indirect
|
||||
github.com/subosito/gotenv v1.2.0 // indirect
|
||||
github.com/tinylib/msgp v1.1.6 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/ysmood/goob v0.3.0 // indirect
|
||||
github.com/ysmood/gson v0.6.4 // indirect
|
||||
github.com/ysmood/leakless v0.7.0 // indirect
|
||||
go.opentelemetry.io/contrib v0.20.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/httptrace/otelhttptrace v0.20.0 // indirect
|
||||
go.opentelemetry.io/otel v0.20.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v0.20.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v0.20.0 // indirect
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect
|
||||
golang.org/x/mod v0.4.2 // indirect
|
||||
golang.org/x/net v0.0.0-20210510120150-4163338589ed // indirect
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect
|
||||
golang.org/x/tools v0.1.2 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/genproto v0.0.0-20210602131652-f16073e35f0c // indirect
|
||||
google.golang.org/grpc v1.38.0 // indirect
|
||||
google.golang.org/protobuf v1.26.0 // indirect
|
||||
gopkg.in/ini.v1 v1.62.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
|
||||
)
|
||||
|
||||
replace (
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible => github.com/mattn/go-sqlite3 v1.14.8
|
||||
github.com/mattn/go-sqlite3 v2.0.3+incompatible => github.com/mattn/go-sqlite3 v1.14.9
|
||||
github.com/tidwall/gjson => github.com/tidwall/gjson v1.11.0
|
||||
)
|
||||
|
|
17
go.sum
17
go.sum
|
@ -45,8 +45,6 @@ github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzS
|
|||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
|
||||
github.com/DataDog/datadog-go v4.0.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
||||
github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4 h1:vdT7QwBhJJEVNFMBNhRSFDRCB6O16T28VhvqRgqFyn8=
|
||||
github.com/Gurpartap/logrus-stack v0.0.0-20170710170904-89c00d8a28f4/go.mod h1:SvXOG8ElV28oAiG9zv91SDe5+9PfIr7PPccpr8YyXNs=
|
||||
|
@ -66,8 +64,6 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdko
|
|||
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
|
||||
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||
github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g=
|
||||
github.com/Workiva/go-datastructures v1.0.53 h1:J6Y/52yX10Xc5JjXmGtWoSSxs3mZnGSaq37xZZh7Yig=
|
||||
github.com/Workiva/go-datastructures v1.0.53/go.mod h1:1yZL+zfsztete+ePzZz/Zb1/t5BnDuE2Ya2MMGhzP6A=
|
||||
github.com/afex/hystrix-go v0.0.0-20180502004556-fa1af6a1f4f5/go.mod h1:SkGFH1ia65gfNATL8TAiHDNxPzPdmEL5uirI2Uyuz6c=
|
||||
github.com/ajg/form v0.0.0-20160822230020-523a5da1a92f/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
|
||||
|
@ -91,8 +87,8 @@ github.com/asaskevich/govalidator v0.0.0-20180720115003-f9ffefc3facf/go.mod h1:l
|
|||
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
|
||||
github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
|
||||
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg=
|
||||
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg=
|
||||
github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
|
||||
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ=
|
||||
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw=
|
||||
github.com/aws/aws-lambda-go v1.13.3/go.mod h1:4UKl9IzQMoD+QF79YdCuzCwp8VbmG4VAQwij/eHl5CU=
|
||||
github.com/aws/aws-sdk-go v1.23.19/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
|
||||
github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
|
||||
|
@ -825,6 +821,7 @@ github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHW
|
|||
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
|
||||
github.com/jmoiron/sqlx v0.0.0-20180614180643-0dae4fefe7c0/go.mod h1:IiEW3SEiiErVyFdH8NTuWjSifiEQKUoyK3LNqr2kCHU=
|
||||
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
|
||||
github.com/jmoiron/sqlx v1.3.1 h1:aLN7YINNZ7cYOPK3QC83dbM6KT0NMqVMw961TqrejlE=
|
||||
github.com/jmoiron/sqlx v1.3.1/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ=
|
||||
github.com/joeshaw/multierror v0.0.0-20140124173710-69b34d4ec901/go.mod h1:Z86h9688Y0wesXCyonoVr47MasHilkuLMqGhRZ4Hpak=
|
||||
github.com/joho/godotenv v1.2.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
|
||||
|
@ -948,8 +945,9 @@ github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOq
|
|||
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
|
||||
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU=
|
||||
github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/mattn/go-sqlite3 v1.14.9 h1:10HX2Td0ocZpYEjhilsuo6WWtUqttj2Kb0KtD86/KYA=
|
||||
github.com/mattn/go-sqlite3 v1.14.9/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
|
||||
github.com/mattn/goveralls v0.0.6 h1:cr8Y0VMo/MnEZBjxNN/vh6G90SZ7IMb6lms1dzMoO+Y=
|
||||
github.com/mattn/goveralls v0.0.6/go.mod h1:h8b4ow6FxSPMQHF6o2ve3qsclnffZjYTNEKmLesRwqw=
|
||||
|
@ -1304,14 +1302,12 @@ github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7V
|
|||
github.com/tidwall/sjson v1.1.5 h1:wsUceI/XDyZk3J1FUvuuYlK62zJv2HO2Pzb8A5EWdUE=
|
||||
github.com/tidwall/sjson v1.1.5/go.mod h1:VuJzsZnTowhSxWdOgsAnb886i4AjEyTkk7tNtsL7EYE=
|
||||
github.com/tinylib/msgp v1.1.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
|
||||
github.com/tinylib/msgp v1.1.5/go.mod h1:eQsjooMTnV42mHu917E26IogZ2930nFyBQdofk10Udg=
|
||||
github.com/tinylib/msgp v1.1.6 h1:i+SbKraHhnrf9M5MYmvQhFnbLhAXSDWF8WWsuyRdocw=
|
||||
github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
github.com/tstranex/u2f v1.0.0 h1:HhJkSzDDlVSVIVt7pDJwCHQj67k7A5EeBgPmeD+pVsQ=
|
||||
github.com/tstranex/u2f v1.0.0/go.mod h1:eahSLaqAS0zsIEv80+vXT7WanXs7MQQDg3j3wGBSayo=
|
||||
github.com/ttacon/chalk v0.0.0-20160626202418-22c06c80ed31/go.mod h1:onvgF043R+lC5RZ8IT9rBXDaEDnpnw/Cl+HFiw+v/7Q=
|
||||
github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g=
|
||||
github.com/uber/jaeger-client-go v2.15.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
|
||||
github.com/uber/jaeger-client-go v2.22.1+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
|
||||
|
@ -1681,9 +1677,8 @@ golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210902050250-f475640dd07b h1:S7hKs0Flbq0bbc9xgYt4stIEG1zNDFqyrPwAX2Wj/sE=
|
||||
golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/asaskevich/govalidator"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
|
@ -208,6 +207,6 @@ func (p *FileUserProvider) UpdatePassword(username string, newPassword string) e
|
|||
}
|
||||
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (p *FileUserProvider) StartupCheck(_ *logrus.Logger) (err error) {
|
||||
func (p *FileUserProvider) StartupCheck() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ type LDAPUserProvider struct {
|
|||
configuration schema.LDAPAuthenticationBackendConfiguration
|
||||
tlsConfig *tls.Config
|
||||
dialOpts []ldap.DialOpt
|
||||
logger *logrus.Logger
|
||||
log *logrus.Logger
|
||||
connectionFactory LDAPConnectionFactory
|
||||
|
||||
disableResetPassword bool
|
||||
|
@ -72,7 +72,7 @@ func newLDAPUserProvider(configuration schema.LDAPAuthenticationBackendConfigura
|
|||
configuration: configuration,
|
||||
tlsConfig: tlsConfig,
|
||||
dialOpts: dialOpts,
|
||||
logger: logging.Logger(),
|
||||
log: logging.Logger(),
|
||||
connectionFactory: factory,
|
||||
disableResetPassword: disableResetPassword,
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ func (p *LDAPUserProvider) resolveUsersFilter(inputUsername string) (filter stri
|
|||
filter = strings.ReplaceAll(filter, ldapPlaceholderInput, p.ldapEscape(inputUsername))
|
||||
}
|
||||
|
||||
p.logger.Tracef("Computed user filter is %s", filter)
|
||||
p.log.Tracef("Computed user filter is %s", filter)
|
||||
|
||||
return filter
|
||||
}
|
||||
|
@ -223,7 +223,7 @@ func (p *LDAPUserProvider) resolveGroupsFilter(inputUsername string, profile *ld
|
|||
}
|
||||
}
|
||||
|
||||
p.logger.Tracef("Computed groups filter is %s", filter)
|
||||
p.log.Tracef("Computed groups filter is %s", filter)
|
||||
|
||||
return filter, nil
|
||||
}
|
||||
|
@ -262,7 +262,7 @@ func (p *LDAPUserProvider) GetDetails(inputUsername string) (*UserDetails, error
|
|||
|
||||
for _, res := range sr.Entries {
|
||||
if len(res.Attributes) == 0 {
|
||||
p.logger.Warningf("No groups retrieved from LDAP for user %s", inputUsername)
|
||||
p.log.Warningf("No groups retrieved from LDAP for user %s", inputUsername)
|
||||
break
|
||||
}
|
||||
|
||||
|
|
|
@ -4,13 +4,12 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
|
||||
func (p *LDAPUserProvider) StartupCheck() (err error) {
|
||||
conn, err := p.connect(p.configuration.User, p.configuration.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -33,7 +32,7 @@ func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
|
|||
// Iterate the attribute values to see what the server supports.
|
||||
for _, attr := range sr.Entries[0].Attributes {
|
||||
if attr.Name == ldapSupportedExtensionAttribute {
|
||||
logger.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
|
||||
p.log.Tracef("LDAP Supported Extension OIDs: %s", strings.Join(attr.Values, ", "))
|
||||
|
||||
for _, oid := range attr.Values {
|
||||
if oid == ldapOIDPasswdModifyExtension {
|
||||
|
@ -48,7 +47,7 @@ func (p *LDAPUserProvider) StartupCheck(logger *logrus.Logger) (err error) {
|
|||
|
||||
if !p.supportExtensionPasswdModify && !p.disableResetPassword &&
|
||||
p.configuration.Implementation != schema.LDAPImplementationActiveDirectory {
|
||||
logger.Warn("Your LDAP server implementation may not support a method for password hashing " +
|
||||
p.log.Warn("Your LDAP server implementation may not support a method for password hashing " +
|
||||
"known to Authelia, it's strongly recommended you ensure your directory server hashes the password " +
|
||||
"attribute when users reset their password via Authelia.")
|
||||
}
|
||||
|
@ -61,7 +60,7 @@ func (p *LDAPUserProvider) parseDynamicUsersConfiguration() {
|
|||
p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{mail_attribute}", p.configuration.MailAttribute)
|
||||
p.configuration.UsersFilter = strings.ReplaceAll(p.configuration.UsersFilter, "{display_name_attribute}", p.configuration.DisplayNameAttribute)
|
||||
|
||||
p.logger.Tracef("Dynamically generated users filter is %s", p.configuration.UsersFilter)
|
||||
p.log.Tracef("Dynamically generated users filter is %s", p.configuration.UsersFilter)
|
||||
|
||||
p.usersAttributes = []string{
|
||||
p.configuration.DisplayNameAttribute,
|
||||
|
@ -75,13 +74,13 @@ func (p *LDAPUserProvider) parseDynamicUsersConfiguration() {
|
|||
p.usersBaseDN = p.configuration.BaseDN
|
||||
}
|
||||
|
||||
p.logger.Tracef("Dynamically generated users BaseDN is %s", p.usersBaseDN)
|
||||
p.log.Tracef("Dynamically generated users BaseDN is %s", p.usersBaseDN)
|
||||
|
||||
if strings.Contains(p.configuration.UsersFilter, ldapPlaceholderInput) {
|
||||
p.usersFilterReplacementInput = true
|
||||
}
|
||||
|
||||
p.logger.Tracef("Detected user filter replacements that need to be resolved per lookup are: %s=%v",
|
||||
p.log.Tracef("Detected user filter replacements that need to be resolved per lookup are: %s=%v",
|
||||
ldapPlaceholderInput, p.usersFilterReplacementInput)
|
||||
}
|
||||
|
||||
|
@ -96,7 +95,7 @@ func (p *LDAPUserProvider) parseDynamicGroupsConfiguration() {
|
|||
p.groupsBaseDN = p.configuration.BaseDN
|
||||
}
|
||||
|
||||
p.logger.Tracef("Dynamically generated groups BaseDN is %s", p.groupsBaseDN)
|
||||
p.log.Tracef("Dynamically generated groups BaseDN is %s", p.groupsBaseDN)
|
||||
|
||||
if strings.Contains(p.configuration.GroupsFilter, ldapPlaceholderInput) {
|
||||
p.groupsFilterReplacementInput = true
|
||||
|
@ -110,5 +109,5 @@ func (p *LDAPUserProvider) parseDynamicGroupsConfiguration() {
|
|||
p.groupsFilterReplacementDN = true
|
||||
}
|
||||
|
||||
p.logger.Tracef("Detected group filter replacements that need to be resolved per lookup are: input=%v, username=%v, dn=%v", p.groupsFilterReplacementInput, p.groupsFilterReplacementUsername, p.groupsFilterReplacementDN)
|
||||
p.log.Tracef("Detected group filter replacements that need to be resolved per lookup are: input=%v, username=%v, dn=%v", p.groupsFilterReplacementInput, p.groupsFilterReplacementUsername, p.groupsFilterReplacementDN)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,6 @@ import (
|
|||
"golang.org/x/text/encoding/unicode"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
|
@ -216,7 +215,7 @@ func TestShouldCheckLDAPServerExtensions(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
|
||||
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
err := ldapClient.StartupCheck()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -273,7 +272,7 @@ func TestShouldNotEnablePasswdModifyExtension(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
|
||||
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
err := ldapClient.StartupCheck()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.False(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -306,7 +305,7 @@ func TestShouldReturnCheckServerConnectError(t *testing.T) {
|
|||
DialURL(gomock.Eq("ldap://127.0.0.1:389"), gomock.Any()).
|
||||
Return(mockConn, errors.New("could not connect"))
|
||||
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
err := ldapClient.StartupCheck()
|
||||
assert.EqualError(t, err, "could not connect")
|
||||
|
||||
assert.False(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -351,7 +350,7 @@ func TestShouldReturnCheckServerSearchError(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURL, connBind, searchOIDs, connClose)
|
||||
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
err := ldapClient.StartupCheck()
|
||||
assert.EqualError(t, err, "could not perform the search")
|
||||
|
||||
assert.False(t, ldapClient.supportExtensionPasswdModify)
|
||||
|
@ -755,7 +754,7 @@ func TestShouldUpdateUserPasswordPasswdModifyExtension(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
|
||||
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
err := ldapClient.StartupCheck()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ldapClient.UpdatePassword("john", "password")
|
||||
|
@ -862,7 +861,7 @@ func TestShouldUpdateUserPasswordActiveDirectory(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
|
||||
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
err := ldapClient.StartupCheck()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ldapClient.UpdatePassword("john", "password")
|
||||
|
@ -966,7 +965,7 @@ func TestShouldUpdateUserPasswordBasic(t *testing.T) {
|
|||
|
||||
gomock.InOrder(dialURLOIDs, connBindOIDs, searchOIDs, connCloseOIDs, dialURL, connBind, searchProfile, passwdModify, connClose)
|
||||
|
||||
err := ldapClient.StartupCheck(logging.Logger())
|
||||
err := ldapClient.StartupCheck()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ldapClient.UpdatePassword("john", "password")
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
package authentication
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
// UserProvider is the interface for checking user password and
|
||||
// gathering user details.
|
||||
type UserProvider interface {
|
||||
models.StartupCheck
|
||||
|
||||
CheckUserPassword(username string, password string) (valid bool, err error)
|
||||
GetDetails(username string) (details *UserDetails, err error)
|
||||
UpdatePassword(username string, newPassword string) (err error)
|
||||
StartupCheck(logger *logrus.Logger) (err error)
|
||||
}
|
||||
|
|
|
@ -75,3 +75,8 @@ PowerShell:
|
|||
PS> authelia completion powershell > authelia.ps1
|
||||
# and source this file from your PowerShell profile.
|
||||
`
|
||||
|
||||
const (
|
||||
storageMigrateDirectionUp = "up"
|
||||
storageMigrateDirectionDown = "down"
|
||||
)
|
||||
|
|
28
internal/commands/helpers.go
Normal file
28
internal/commands/helpers.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/storage"
|
||||
)
|
||||
|
||||
func getStorageProvider() (provider storage.Provider, err error) {
|
||||
switch {
|
||||
case config.Storage.PostgreSQL != nil:
|
||||
provider = storage.NewPostgreSQLProvider(*config.Storage.PostgreSQL)
|
||||
case config.Storage.MySQL != nil:
|
||||
provider = storage.NewMySQLProvider(*config.Storage.MySQL)
|
||||
case config.Storage.Local != nil:
|
||||
provider = storage.NewSQLiteProvider(config.Storage.Local.Path)
|
||||
default:
|
||||
return nil, errors.New("no storage provider configured")
|
||||
}
|
||||
|
||||
if (config.Storage.MySQL != nil && config.Storage.PostgreSQL != nil) ||
|
||||
(config.Storage.MySQL != nil && config.Storage.Local != nil) ||
|
||||
(config.Storage.PostgreSQL != nil && config.Storage.Local != nil) {
|
||||
return nil, errors.New("multiple storage providers are configured but should only configure one")
|
||||
}
|
||||
|
||||
return provider, err
|
||||
}
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/notification"
|
||||
"github.com/authelia/authelia/v4/internal/ntp"
|
||||
"github.com/authelia/authelia/v4/internal/oidc"
|
||||
|
@ -46,6 +47,7 @@ func NewRootCmd() (cmd *cobra.Command) {
|
|||
newCompletionCmd(),
|
||||
NewHashPasswordCmd(),
|
||||
NewRSACmd(),
|
||||
NewStorageCmd(),
|
||||
newValidateConfigCmd(),
|
||||
)
|
||||
|
||||
|
@ -101,9 +103,6 @@ func getProviders(config *schema.Configuration) (providers middlewares.Providers
|
|||
storageProvider = storage.NewMySQLProvider(*config.Storage.MySQL)
|
||||
case config.Storage.Local != nil:
|
||||
storageProvider = storage.NewSQLiteProvider(config.Storage.Local.Path)
|
||||
default:
|
||||
// TODO: Add storage provider startup check and remove this.
|
||||
errors = append(errors, fmt.Errorf("unrecognized storage provider"))
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -162,6 +161,12 @@ func doStartupChecks(config *schema.Configuration, providers *middlewares.Provid
|
|||
err error
|
||||
)
|
||||
|
||||
if err = doStartupCheck(logger, "storage", providers.StorageProvider, false); err != nil {
|
||||
logger.Errorf("Failure running the storage provider startup check: %+v", err)
|
||||
|
||||
failures = append(failures, "storage")
|
||||
}
|
||||
|
||||
if err = doStartupCheck(logger, "user", providers.UserProvider, false); err != nil {
|
||||
logger.Errorf("Failure running the user provider startup check: %+v", err)
|
||||
|
||||
|
@ -187,7 +192,7 @@ func doStartupChecks(config *schema.Configuration, providers *middlewares.Provid
|
|||
}
|
||||
}
|
||||
|
||||
func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.ProviderWithStartupCheck, disabled bool) (err error) {
|
||||
func doStartupCheck(logger *logrus.Logger, name string, provider models.StartupCheck, disabled bool) (err error) {
|
||||
if disabled {
|
||||
logger.Debugf("%s provider: startup check skipped as it is disabled", name)
|
||||
return nil
|
||||
|
@ -197,7 +202,7 @@ func doStartupCheck(logger *logrus.Logger, name string, provider middlewares.Pro
|
|||
return fmt.Errorf("unrecognized provider or it is not configured properly")
|
||||
}
|
||||
|
||||
if err = provider.StartupCheck(logger); err != nil {
|
||||
if err = provider.StartupCheck(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
126
internal/commands/storage.go
Normal file
126
internal/commands/storage.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// NewStorageCmd returns a new storage *cobra.Command.
|
||||
func NewStorageCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: "storage",
|
||||
Short: "Manage the Authelia storage",
|
||||
Args: cobra.NoArgs,
|
||||
PersistentPreRunE: storagePersistentPreRunE,
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().StringSliceP("config", "c", []string{"config.yml"}, "configuration file to load for the storage migration")
|
||||
|
||||
cmd.PersistentFlags().String("sqlite.path", "", "the SQLite database path")
|
||||
|
||||
cmd.PersistentFlags().String("mysql.host", "", "the MySQL hostname")
|
||||
cmd.PersistentFlags().Int("mysql.port", 3306, "the MySQL port")
|
||||
cmd.PersistentFlags().String("mysql.database", "authelia", "the MySQL database name")
|
||||
cmd.PersistentFlags().String("mysql.username", "authelia", "the MySQL username")
|
||||
cmd.PersistentFlags().String("mysql.password", "", "the MySQL password")
|
||||
|
||||
cmd.PersistentFlags().String("postgres.host", "", "the PostgreSQL hostname")
|
||||
cmd.PersistentFlags().Int("postgres.port", 5432, "the PostgreSQL port")
|
||||
cmd.PersistentFlags().String("postgres.database", "authelia", "the PostgreSQL database name")
|
||||
cmd.PersistentFlags().String("postgres.username", "authelia", "the PostgreSQL username")
|
||||
cmd.PersistentFlags().String("postgres.password", "", "the PostgreSQL password")
|
||||
|
||||
cmd.AddCommand(
|
||||
newStorageMigrateCmd(),
|
||||
newStorageSchemaInfoCmd(),
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newStorageSchemaInfoCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: "schema-info",
|
||||
Short: "Show the storage information",
|
||||
RunE: storageSchemaInfoRunE,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// NewMigrationCmd returns a new Migration Cmd.
|
||||
func newStorageMigrateCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: "migrate",
|
||||
Short: "Perform or list migrations",
|
||||
Args: cobra.NoArgs,
|
||||
}
|
||||
|
||||
cmd.AddCommand(
|
||||
newStorageMigrateUpCmd(), newStorageMigrateDownCmd(),
|
||||
newStorageMigrateListUpCmd(), newStorageMigrateListDownCmd(),
|
||||
newStorageMigrateHistoryCmd(),
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newStorageMigrateHistoryCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: "history",
|
||||
Short: "Show migration history",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: storageMigrateHistoryRunE,
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newStorageMigrateListUpCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: "list-up",
|
||||
Short: "List the up migrations available",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: newStorageMigrateListRunE(true),
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newStorageMigrateListDownCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: "list-down",
|
||||
Short: "List the down migrations available",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: newStorageMigrateListRunE(false),
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newStorageMigrateUpCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: storageMigrateDirectionUp,
|
||||
Short: "Perform a migration up",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: newStorageMigrationRunE(true),
|
||||
}
|
||||
|
||||
cmd.Flags().IntP("target", "t", 0, "sets the version to migrate to, by default this is the latest version")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newStorageMigrateDownCmd() (cmd *cobra.Command) {
|
||||
cmd = &cobra.Command{
|
||||
Use: storageMigrateDirectionDown,
|
||||
Short: "Perform a migration down",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: newStorageMigrationRunE(false),
|
||||
}
|
||||
|
||||
cmd.Flags().IntP("target", "t", 0, "sets the version to migrate to")
|
||||
cmd.Flags().Bool("pre1", false, "sets pre1 as the version to migrate to")
|
||||
cmd.Flags().Bool("destroy-data", false, "confirms you want to destroy data with this migration")
|
||||
|
||||
return cmd
|
||||
}
|
291
internal/commands/storage_run.go
Normal file
291
internal/commands/storage_run.go
Normal file
|
@ -0,0 +1,291 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration"
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/configuration/validator"
|
||||
"github.com/authelia/authelia/v4/internal/storage"
|
||||
)
|
||||
|
||||
func storagePersistentPreRunE(cmd *cobra.Command, _ []string) (err error) {
|
||||
configs, err := cmd.Flags().GetStringSlice("config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sources := make([]configuration.Source, 0, len(configs)+3)
|
||||
|
||||
if cmd.Flags().Changed("config") {
|
||||
for _, configFile := range configs {
|
||||
if _, err := os.Stat(configFile); os.IsNotExist(err) {
|
||||
return fmt.Errorf("could not load the provided configuration file %s: %w", configFile, err)
|
||||
}
|
||||
|
||||
sources = append(sources, configuration.NewYAMLFileSource(configFile))
|
||||
}
|
||||
} else {
|
||||
if _, err := os.Stat(configs[0]); err == nil {
|
||||
sources = append(sources, configuration.NewYAMLFileSource(configs[0]))
|
||||
}
|
||||
}
|
||||
|
||||
mapping := map[string]string{
|
||||
"sqlite.path": "storage.local.path",
|
||||
"mysql.host": "storage.mysql.host",
|
||||
"mysql.port": "storage.mysql.port",
|
||||
"mysql.database": "storage.mysql.database",
|
||||
"mysql.username": "storage.mysql.username",
|
||||
"mysql.password": "storage.mysql.password",
|
||||
"postgres.host": "storage.postgres.host",
|
||||
"postgres.port": "storage.postgres.port",
|
||||
"postgres.database": "storage.postgres.database",
|
||||
"postgres.username": "storage.postgres.username",
|
||||
"postgres.password": "storage.postgres.password",
|
||||
"postgres.schema": "storage.postgres.schema",
|
||||
}
|
||||
|
||||
sources = append(sources, configuration.NewEnvironmentSource(configuration.DefaultEnvPrefix, configuration.DefaultEnvDelimiter))
|
||||
sources = append(sources, configuration.NewSecretsSource(configuration.DefaultEnvPrefix, configuration.DefaultEnvDelimiter))
|
||||
sources = append(sources, configuration.NewCommandLineSourceWithMapping(cmd.Flags(), mapping, true, false))
|
||||
|
||||
val := schema.NewStructValidator()
|
||||
|
||||
config = &schema.Configuration{}
|
||||
|
||||
_, err = configuration.LoadAdvanced(val, "storage", &config.Storage, sources...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.HasErrors() {
|
||||
var finalErr error
|
||||
|
||||
for i, err := range val.Errors() {
|
||||
if i == 0 {
|
||||
finalErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
finalErr = fmt.Errorf("%w, %v", finalErr, err)
|
||||
}
|
||||
|
||||
return finalErr
|
||||
}
|
||||
|
||||
validator.ValidateStorage(config.Storage, val)
|
||||
|
||||
if val.HasErrors() {
|
||||
var finalErr error
|
||||
|
||||
for i, err := range val.Errors() {
|
||||
if i == 0 {
|
||||
finalErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
finalErr = fmt.Errorf("%w, %v", finalErr, err)
|
||||
}
|
||||
|
||||
return finalErr
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func storageMigrateHistoryRunE(_ *cobra.Command, _ []string) (err error) {
|
||||
var (
|
||||
provider storage.Provider
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
provider, err = getStorageProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrations, err := provider.SchemaMigrationHistory(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(migrations) == 0 {
|
||||
return errors.New("no migration history found which may indicate a broken schema")
|
||||
}
|
||||
|
||||
fmt.Printf("Migration History:\n\nID\tDate\t\t\t\tBefore\tAfter\tAuthelia Version\n")
|
||||
|
||||
for _, m := range migrations {
|
||||
fmt.Printf("%d\t%s\t%d\t%d\t%s\n", m.ID, m.Applied.Format("2006-01-02 15:04:05 -0700"), m.Before, m.After, m.Version)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newStorageMigrateListRunE(up bool) func(cmd *cobra.Command, args []string) (err error) {
|
||||
return func(cmd *cobra.Command, args []string) (err error) {
|
||||
var (
|
||||
provider storage.Provider
|
||||
ctx = context.Background()
|
||||
migrations []storage.SchemaMigration
|
||||
directionStr string
|
||||
)
|
||||
|
||||
provider, err = getStorageProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if up {
|
||||
migrations, err = provider.SchemaMigrationsUp(ctx, 0)
|
||||
directionStr = "Up"
|
||||
} else {
|
||||
migrations, err = provider.SchemaMigrationsDown(ctx, 0)
|
||||
directionStr = "Down"
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err.Error() == "cannot migrate to the same version as prior" {
|
||||
fmt.Printf("No %s migrations found\n", directionStr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
if len(migrations) == 0 {
|
||||
fmt.Printf("Storage Schema Migration List (%s)\n\nNo Migrations Available\n", directionStr)
|
||||
} else {
|
||||
fmt.Printf("Storage Schema Migration List (%s)\n\nVersion\t\tDescription\n", directionStr)
|
||||
|
||||
for _, migration := range migrations {
|
||||
fmt.Printf("%d\t\t%s\n", migration.Version, migration.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func newStorageMigrationRunE(up bool) func(cmd *cobra.Command, args []string) (err error) {
|
||||
return func(cmd *cobra.Command, args []string) (err error) {
|
||||
var (
|
||||
provider storage.Provider
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
provider, err = getStorageProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
target, err := cmd.Flags().GetInt("target")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case up:
|
||||
switch cmd.Flags().Changed("target") {
|
||||
case true:
|
||||
return provider.SchemaMigrate(ctx, true, target)
|
||||
default:
|
||||
return provider.SchemaMigrate(ctx, true, storage.SchemaLatest)
|
||||
}
|
||||
default:
|
||||
if !cmd.Flags().Changed("target") {
|
||||
return errors.New("must set target")
|
||||
}
|
||||
|
||||
if err = storageMigrateDownConfirmDestroy(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pre1, err := cmd.Flags().GetBool("pre1")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case pre1:
|
||||
return provider.SchemaMigrate(ctx, false, -1)
|
||||
default:
|
||||
return provider.SchemaMigrate(ctx, false, target)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func storageMigrateDownConfirmDestroy(cmd *cobra.Command) (err error) {
|
||||
destroy, err := cmd.Flags().GetBool("destroy-data")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !destroy {
|
||||
fmt.Printf("Schema Down Migrations may DESTROY data, type 'DESTROY' and press return to continue: ")
|
||||
|
||||
var text string
|
||||
|
||||
_, _ = fmt.Scanln(&text)
|
||||
|
||||
if text != "DESTROY" {
|
||||
return errors.New("cancelling down migration due to user not accepting data destruction")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func storageSchemaInfoRunE(_ *cobra.Command, _ []string) (err error) {
|
||||
var (
|
||||
provider storage.Provider
|
||||
ctx = context.Background()
|
||||
upgradeStr string
|
||||
tablesStr string
|
||||
)
|
||||
|
||||
provider, err = getStorageProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
version, err := provider.SchemaVersion(ctx)
|
||||
if err != nil && err.Error() != "unknown schema state" {
|
||||
return err
|
||||
}
|
||||
|
||||
tables, err := provider.SchemaTables(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(tables) == 0 {
|
||||
tablesStr = "N/A"
|
||||
} else {
|
||||
tablesStr = strings.Join(tables, ", ")
|
||||
}
|
||||
|
||||
latest, err := provider.SchemaLatestVersion()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if latest > version {
|
||||
upgradeStr = fmt.Sprintf("yes - version %d", latest)
|
||||
} else {
|
||||
upgradeStr = "no"
|
||||
}
|
||||
|
||||
fmt.Printf("Schema Version: %s\nSchema Upgrade Available: %s\nSchema Tables: %s\n", storage.SchemaVersionToString(version), upgradeStr, tablesStr)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -4,6 +4,8 @@ import (
|
|||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/configuration/validator"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
|
@ -48,3 +50,25 @@ func koanfEnvironmentSecretsCallback(keyMap map[string]string, validator *schema
|
|||
return k, v
|
||||
}
|
||||
}
|
||||
|
||||
func koanfCommandLineWithMappingCallback(mapping map[string]string, includeValidKeys, includeUnchangedKeys bool) func(flag *pflag.Flag) (string, interface{}) {
|
||||
return func(flag *pflag.Flag) (string, interface{}) {
|
||||
if !includeUnchangedKeys && !flag.Changed {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if actualKey, ok := mapping[flag.Name]; ok {
|
||||
return actualKey, flag.Value.String()
|
||||
}
|
||||
|
||||
if includeValidKeys {
|
||||
formattedKey := strings.ReplaceAll(flag.Name, "-", "_")
|
||||
|
||||
if utils.IsStringInSlice(formattedKey, validator.ValidKeys) {
|
||||
return formattedKey, flag.Value.String()
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,8 +11,17 @@ import (
|
|||
|
||||
// Load the configuration given the provided options and sources.
|
||||
func Load(val *schema.StructValidator, sources ...Source) (keys []string, configuration *schema.Configuration, err error) {
|
||||
configuration = &schema.Configuration{}
|
||||
|
||||
keys, err = LoadAdvanced(val, "", configuration, sources...)
|
||||
|
||||
return keys, configuration, err
|
||||
}
|
||||
|
||||
// LoadAdvanced is intended to give more flexibility over loading a particular path to a specific interface.
|
||||
func LoadAdvanced(val *schema.StructValidator, path string, result interface{}, sources ...Source) (keys []string, err error) {
|
||||
if val == nil {
|
||||
return keys, configuration, errNoValidator
|
||||
return keys, errNoValidator
|
||||
}
|
||||
|
||||
ko := koanf.NewWithConf(koanf.Conf{
|
||||
|
@ -22,14 +31,12 @@ func Load(val *schema.StructValidator, sources ...Source) (keys []string, config
|
|||
|
||||
err = loadSources(ko, val, sources...)
|
||||
if err != nil {
|
||||
return ko.Keys(), configuration, err
|
||||
return ko.Keys(), err
|
||||
}
|
||||
|
||||
configuration = &schema.Configuration{}
|
||||
unmarshal(ko, val, path, result)
|
||||
|
||||
unmarshal(ko, val, "", configuration)
|
||||
|
||||
return ko.Keys(), configuration, nil
|
||||
return ko.Keys(), nil
|
||||
}
|
||||
|
||||
func unmarshal(ko *koanf.Koanf, val *schema.StructValidator, path string, o interface{}) {
|
||||
|
|
|
@ -1,12 +1,5 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/Workiva/go-datastructures/queue"
|
||||
)
|
||||
|
||||
// ErrorContainer represents a container where we can add errors and retrieve them.
|
||||
type ErrorContainer interface {
|
||||
Push(err error)
|
||||
|
@ -17,100 +10,6 @@ type ErrorContainer interface {
|
|||
Warnings() []error
|
||||
}
|
||||
|
||||
// Validator represents the validator interface.
|
||||
type Validator struct {
|
||||
errors map[string][]error
|
||||
}
|
||||
|
||||
// NewValidator create a validator.
|
||||
func NewValidator() *Validator {
|
||||
validator := new(Validator)
|
||||
validator.errors = make(map[string][]error)
|
||||
|
||||
return validator
|
||||
}
|
||||
|
||||
// QueueItem an item representing a struct field and its path.
|
||||
type QueueItem struct {
|
||||
value reflect.Value
|
||||
path string
|
||||
}
|
||||
|
||||
func (v *Validator) validateOne(item QueueItem, q *queue.Queue) error { //nolint:unparam
|
||||
if item.value.Type().Kind() == reflect.Ptr {
|
||||
if item.value.IsNil() {
|
||||
return nil
|
||||
}
|
||||
|
||||
elem := item.value.Elem()
|
||||
|
||||
q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
|
||||
value: elem,
|
||||
path: item.path,
|
||||
})
|
||||
} else if item.value.Kind() == reflect.Struct {
|
||||
numFields := item.value.Type().NumField()
|
||||
|
||||
validateFn := item.value.Addr().MethodByName("Validate")
|
||||
|
||||
if validateFn.IsValid() {
|
||||
structValidator := NewStructValidator()
|
||||
validateFn.Call([]reflect.Value{reflect.ValueOf(structValidator)})
|
||||
v.errors[item.path] = structValidator.Errors()
|
||||
}
|
||||
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := item.value.Type().Field(i)
|
||||
value := item.value.Field(i)
|
||||
|
||||
q.Put(QueueItem{ //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
|
||||
value: value,
|
||||
path: item.path + "." + field.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validate a struct.
|
||||
func (v *Validator) Validate(s interface{}) error {
|
||||
q := queue.New(40)
|
||||
q.Put(QueueItem{value: reflect.ValueOf(s), path: "root"}) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
|
||||
|
||||
for !q.Empty() {
|
||||
val, err := q.Get(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
item, ok := val[0].(QueueItem)
|
||||
if !ok {
|
||||
return fmt.Errorf("Cannot convert item into QueueItem")
|
||||
}
|
||||
|
||||
v.validateOne(item, q) //nolint:errcheck // TODO: Legacy code, consider refactoring time permitting.
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrintErrors display the errors thrown during validation.
|
||||
func (v *Validator) PrintErrors() {
|
||||
for path, errs := range v.errors {
|
||||
fmt.Printf("Errors at %s:\n", path)
|
||||
|
||||
for _, err := range errs {
|
||||
fmt.Printf("--> %s\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Errors return the errors thrown during validation.
|
||||
func (v *Validator) Errors() map[string][]error {
|
||||
return v.errors
|
||||
}
|
||||
|
||||
// StructValidator is a validator for structs.
|
||||
type StructValidator struct {
|
||||
errors []error
|
||||
|
|
|
@ -49,43 +49,6 @@ func (ts *TestStruct) Validate(validator *schema.StructValidator) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestValidator(t *testing.T) {
|
||||
validator := schema.NewValidator()
|
||||
|
||||
s := TestStruct{
|
||||
MustBe10: 5,
|
||||
NotEmpty: "",
|
||||
NestedPtr: &TestNestedStruct{},
|
||||
}
|
||||
|
||||
err := validator.Validate(&s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
errs := validator.Errors()
|
||||
assert.Equal(t, 4, len(errs))
|
||||
|
||||
assert.Equal(t, 2, len(errs["root"]))
|
||||
assert.ElementsMatch(t, []error{
|
||||
fmt.Errorf("MustBe10 must be 10"),
|
||||
fmt.Errorf("NotEmpty must not be empty")}, errs["root"])
|
||||
|
||||
assert.Equal(t, 1, len(errs["root.Nested"]))
|
||||
assert.ElementsMatch(t, []error{
|
||||
fmt.Errorf("MustBe5 must be 5")}, errs["root.Nested"])
|
||||
|
||||
assert.Equal(t, 1, len(errs["root.Nested2"]))
|
||||
assert.ElementsMatch(t, []error{
|
||||
fmt.Errorf("MustBe5 must be 5")}, errs["root.Nested2"])
|
||||
|
||||
assert.Equal(t, 1, len(errs["root.NestedPtr"]))
|
||||
assert.ElementsMatch(t, []error{
|
||||
fmt.Errorf("MustBe5 must be 5")}, errs["root.NestedPtr"])
|
||||
|
||||
assert.Equal(t, "xyz", s.SetDefault)
|
||||
}
|
||||
|
||||
func TestStructValidator(t *testing.T) {
|
||||
validator := schema.NewStructValidator()
|
||||
s := TestStruct{
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"github.com/knadh/koanf/parsers/yaml"
|
||||
"github.com/knadh/koanf/providers/env"
|
||||
"github.com/knadh/koanf/providers/file"
|
||||
"github.com/knadh/koanf/providers/posflag"
|
||||
"github.com/spf13/pflag"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/configuration/validator"
|
||||
|
@ -112,6 +114,37 @@ func (s *SecretsSource) Load(val *schema.StructValidator) (err error) {
|
|||
return s.koanf.Load(env.ProviderWithValue(s.prefix, constDelimiter, koanfEnvironmentSecretsCallback(keyMap, val)), nil)
|
||||
}
|
||||
|
||||
// NewCommandLineSourceWithMapping creates a new command line configuration source with a map[string]string which converts
|
||||
// flag names into other config key names. If includeValidKeys is true we also allow any flag with a name which matches
|
||||
// the list of valid keys into the koanf.Koanf, otherwise everything not in the map is skipped. Unchanged flags are also
|
||||
// skipped unless includeUnchangedKeys is set to true.
|
||||
func NewCommandLineSourceWithMapping(flags *pflag.FlagSet, mapping map[string]string, includeValidKeys, includeUnchangedKeys bool) (source *CommandLineSource) {
|
||||
return &CommandLineSource{
|
||||
koanf: koanf.New(constDelimiter),
|
||||
flags: flags,
|
||||
callback: koanfCommandLineWithMappingCallback(mapping, includeValidKeys, includeUnchangedKeys),
|
||||
}
|
||||
}
|
||||
|
||||
// Name of the Source.
|
||||
func (s CommandLineSource) Name() (name string) {
|
||||
return "command-line"
|
||||
}
|
||||
|
||||
// Merge the CommandLineSource koanf.Koanf into the provided one.
|
||||
func (s *CommandLineSource) Merge(ko *koanf.Koanf, val *schema.StructValidator) (err error) {
|
||||
return ko.Merge(s.koanf)
|
||||
}
|
||||
|
||||
// Load the Source into the YAMLFileSource koanf.Koanf.
|
||||
func (s *CommandLineSource) Load(_ *schema.StructValidator) (err error) {
|
||||
if s.callback != nil {
|
||||
return s.koanf.Load(posflag.ProviderWithFlag(s.flags, ".", s.koanf, s.callback), nil)
|
||||
}
|
||||
|
||||
return s.koanf.Load(posflag.Provider(s.flags, ".", s.koanf), nil)
|
||||
}
|
||||
|
||||
// NewDefaultSources returns a slice of Source configured to load from specified YAML files.
|
||||
func NewDefaultSources(filePaths []string, prefix, delimiter string) (sources []Source) {
|
||||
fileSources := NewYAMLFileSources(filePaths)
|
||||
|
|
|
@ -2,6 +2,7 @@ package configuration
|
|||
|
||||
import (
|
||||
"github.com/knadh/koanf"
|
||||
"github.com/spf13/pflag"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
@ -32,3 +33,10 @@ type SecretsSource struct {
|
|||
prefix string
|
||||
delimiter string
|
||||
}
|
||||
|
||||
// CommandLineSource loads configuration from the command line flags.
|
||||
type CommandLineSource struct {
|
||||
koanf *koanf.Koanf
|
||||
flags *pflag.FlagSet
|
||||
callback func(flag *pflag.Flag) (string, interface{})
|
||||
}
|
||||
|
|
|
@ -73,6 +73,12 @@ const (
|
|||
pathOpenIDConnectConsent = "/api/oidc/consent"
|
||||
)
|
||||
|
||||
const (
|
||||
totpAlgoSHA1 = "SHA1"
|
||||
totpAlgoSHA256 = "SHA256"
|
||||
totpAlgoSHA512 = "SHA512"
|
||||
)
|
||||
|
||||
const (
|
||||
accept = "accept"
|
||||
reject = "reject"
|
||||
|
|
|
@ -77,7 +77,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
|
|||
return
|
||||
}
|
||||
|
||||
bannedUntil, err := ctx.Providers.Regulator.Regulate(bodyJSON.Username)
|
||||
bannedUntil, err := ctx.Providers.Regulator.Regulate(ctx, bodyJSON.Username)
|
||||
|
||||
if err != nil {
|
||||
if err == regulation.ErrUserIsBanned {
|
||||
|
@ -95,7 +95,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
|
|||
if err != nil {
|
||||
ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username)
|
||||
|
||||
if err := ctx.Providers.Regulator.Mark(bodyJSON.Username, false); err != nil {
|
||||
if err := ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, false); err != nil {
|
||||
ctx.Logger.Errorf("Unable to mark authentication: %s", err.Error())
|
||||
}
|
||||
|
||||
|
@ -107,7 +107,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
|
|||
if !userPasswordOk {
|
||||
ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username)
|
||||
|
||||
if err := ctx.Providers.Regulator.Mark(bodyJSON.Username, false); err != nil {
|
||||
if err := ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, false); err != nil {
|
||||
ctx.Logger.Errorf("Unable to mark authentication: %s", err.Error())
|
||||
}
|
||||
|
||||
|
@ -117,7 +117,7 @@ func FirstFactorPost(msInitialDelay time.Duration, delayEnabled bool) middleware
|
|||
}
|
||||
|
||||
ctx.Logger.Debugf("Mark authentication attempt made by user %s", bodyJSON.Username)
|
||||
err = ctx.Providers.Regulator.Mark(bodyJSON.Username, true)
|
||||
err = ctx.Providers.Regulator.Mark(ctx, bodyJSON.Username, true)
|
||||
|
||||
if err != nil {
|
||||
handleAuthenticationUnauthorized(ctx, fmt.Errorf("unable to mark authentication: %s", err.Error()), messageAuthenticationFailed)
|
||||
|
|
|
@ -58,7 +58,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Eq(models.AuthenticationAttempt{
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{
|
||||
Username: "test",
|
||||
Successful: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
|
@ -83,7 +83,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Eq(models.AuthenticationAttempt{
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Eq(models.AuthenticationAttempt{
|
||||
Username: "test",
|
||||
Successful: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
|
@ -106,7 +106,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderGetDetailsFail() {
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Any()).
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
s.mock.UserProviderMock.
|
||||
|
@ -133,7 +133,7 @@ func (s *FirstFactorSuite) TestShouldFailIfAuthenticationMarkFail() {
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Any()).
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
|
||||
Return(fmt.Errorf("failed"))
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
|
@ -164,7 +164,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeChecked() {
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Any()).
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
|
@ -204,7 +204,7 @@ func (s *FirstFactorSuite) TestShouldAuthenticateUserWithRememberMeUnchecked() {
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Any()).
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
|
@ -248,7 +248,7 @@ func (s *FirstFactorSuite) TestShouldSaveUsernameFromAuthenticationBackendInSess
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Any()).
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
|
@ -306,7 +306,7 @@ func (s *FirstFactorRedirectionSuite) SetupTest() {
|
|||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
AppendAuthenticationLog(gomock.Any()).
|
||||
AppendAuthenticationLog(s.mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,9 +3,11 @@ package handlers
|
|||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
|
@ -37,11 +39,15 @@ var SecondFactorTOTPIdentityStart = middlewares.IdentityVerificationStart(middle
|
|||
})
|
||||
|
||||
func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username string) {
|
||||
algorithm := otp.AlgorithmSHA1
|
||||
|
||||
key, err := totp.Generate(totp.GenerateOpts{
|
||||
Issuer: ctx.Configuration.TOTP.Issuer,
|
||||
AccountName: username,
|
||||
SecretSize: 32,
|
||||
Period: uint(ctx.Configuration.TOTP.Period),
|
||||
SecretSize: 32,
|
||||
Digits: otp.Digits(6),
|
||||
Algorithm: algorithm,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -49,7 +55,15 @@ func secondFactorTOTPIdentityFinish(ctx *middlewares.AutheliaCtx, username strin
|
|||
return
|
||||
}
|
||||
|
||||
err = ctx.Providers.StorageProvider.SaveTOTPSecret(username, key.Secret())
|
||||
config := models.TOTPConfiguration{
|
||||
Username: username,
|
||||
Algorithm: otpAlgoToString(algorithm),
|
||||
Digits: 6,
|
||||
Secret: key.Secret(),
|
||||
Period: key.Period(),
|
||||
}
|
||||
|
||||
err = ctx.Providers.StorageProvider.SaveTOTPConfiguration(ctx, config)
|
||||
if err != nil {
|
||||
ctx.Error(fmt.Errorf("unable to save TOTP secret in DB: %s", err), messageUnableToRegisterOneTimePassword)
|
||||
return
|
||||
|
|
|
@ -57,11 +57,11 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
|
|||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq(token)).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(true, nil)
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
RemoveIdentityVerificationToken(gomock.Eq(token)).
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(nil)
|
||||
|
||||
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
||||
|
@ -77,11 +77,11 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissin
|
|||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq(token)).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(true, nil)
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
RemoveIdentityVerificationToken(gomock.Eq(token)).
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(nil)
|
||||
|
||||
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/tstranex/u2f"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
// SecondFactorU2FRegister handler validating the client has successfully validated the challenge
|
||||
|
@ -45,7 +46,12 @@ func SecondFactorU2FRegister(ctx *middlewares.AutheliaCtx) {
|
|||
ctx.Logger.Debugf("Register U2F device for user %s", userSession.Username)
|
||||
|
||||
publicKey := elliptic.Marshal(elliptic.P256(), registration.PubKey.X, registration.PubKey.Y)
|
||||
err = ctx.Providers.StorageProvider.SaveU2FDeviceHandle(userSession.Username, registration.KeyHandle, publicKey)
|
||||
|
||||
err = ctx.Providers.StorageProvider.SaveU2FDevice(ctx, models.U2FDevice{
|
||||
Username: userSession.Username,
|
||||
KeyHandle: registration.KeyHandle,
|
||||
PublicKey: publicKey},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
ctx.Error(fmt.Errorf("unable to register U2F device for user %s: %v", userSession.Username, err), messageUnableToRegisterSecurityKey)
|
||||
|
|
|
@ -19,13 +19,13 @@ func SecondFactorTOTPPost(totpVerifier TOTPVerifier) middlewares.RequestHandler
|
|||
|
||||
userSession := ctx.GetSession()
|
||||
|
||||
secret, err := ctx.Providers.StorageProvider.LoadTOTPSecret(userSession.Username)
|
||||
config, err := ctx.Providers.StorageProvider.LoadTOTPConfiguration(ctx, userSession.Username)
|
||||
if err != nil {
|
||||
handleAuthenticationUnauthorized(ctx, fmt.Errorf("unable to load TOTP secret: %s", err), messageMFAValidationFailed)
|
||||
return
|
||||
}
|
||||
|
||||
isValid, err := totpVerifier.Verify(requestBody.Token, secret)
|
||||
isValid, err := totpVerifier.Verify(config, requestBody.Token)
|
||||
if err != nil {
|
||||
handleAuthenticationUnauthorized(ctx, fmt.Errorf("error occurred during OTP validation for user %s: %s", userSession.Username, err), messageMFAValidationFailed)
|
||||
return
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/tstranex/u2f"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/mocks"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/session"
|
||||
)
|
||||
|
||||
|
@ -37,12 +38,14 @@ func (s *HandlerSignTOTPSuite) TearDownTest() {
|
|||
func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() {
|
||||
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
|
||||
|
||||
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
LoadTOTPSecret(gomock.Any()).
|
||||
Return("secret", nil)
|
||||
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
|
||||
Return(&config, nil)
|
||||
|
||||
verifier.EXPECT().
|
||||
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
|
||||
Verify(gomock.Eq(&config), gomock.Eq("abc")).
|
||||
Return(true, nil)
|
||||
|
||||
s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL
|
||||
|
@ -62,12 +65,14 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() {
|
|||
func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() {
|
||||
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
|
||||
|
||||
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
LoadTOTPSecret(gomock.Any()).
|
||||
Return("secret", nil)
|
||||
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
|
||||
Return(&config, nil)
|
||||
|
||||
verifier.EXPECT().
|
||||
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
|
||||
Verify(gomock.Eq(&config), gomock.Eq("abc")).
|
||||
Return(true, nil)
|
||||
|
||||
bodyBytes, err := json.Marshal(signTOTPRequestBody{
|
||||
|
@ -83,12 +88,14 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() {
|
|||
func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() {
|
||||
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
|
||||
|
||||
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
LoadTOTPSecret(gomock.Any()).
|
||||
Return("secret", nil)
|
||||
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
|
||||
Return(&config, nil)
|
||||
|
||||
verifier.EXPECT().
|
||||
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
|
||||
Verify(gomock.Eq(&config), gomock.Eq("abc")).
|
||||
Return(true, nil)
|
||||
|
||||
bodyBytes, err := json.Marshal(signTOTPRequestBody{
|
||||
|
@ -108,11 +115,11 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() {
|
|||
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
LoadTOTPSecret(gomock.Any()).
|
||||
Return("secret", nil)
|
||||
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
|
||||
Return(&models.TOTPConfiguration{Secret: "secret"}, nil)
|
||||
|
||||
verifier.EXPECT().
|
||||
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
|
||||
Verify(gomock.Eq(&models.TOTPConfiguration{Secret: "secret"}), gomock.Eq("abc")).
|
||||
Return(true, nil)
|
||||
|
||||
bodyBytes, err := json.Marshal(signTOTPRequestBody{
|
||||
|
@ -129,12 +136,14 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() {
|
|||
func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFixation() {
|
||||
verifier := NewMockTOTPVerifier(s.mock.Ctrl)
|
||||
|
||||
config := models.TOTPConfiguration{ID: 1, Username: "john", Digits: 6, Secret: "secret", Period: 30, Algorithm: "SHA1"}
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
LoadTOTPSecret(gomock.Any()).
|
||||
Return("secret", nil)
|
||||
LoadTOTPConfiguration(s.mock.Ctx, gomock.Any()).
|
||||
Return(&config, nil)
|
||||
|
||||
verifier.EXPECT().
|
||||
Verify(gomock.Eq("abc"), gomock.Eq("secret")).
|
||||
Verify(gomock.Eq(&config), gomock.Eq("abc")).
|
||||
Return(true, nil)
|
||||
|
||||
bodyBytes, err := json.Marshal(signTOTPRequestBody{
|
||||
|
|
|
@ -34,7 +34,7 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
|
|||
}
|
||||
|
||||
userSession := ctx.GetSession()
|
||||
keyHandleBytes, publicKeyBytes, err := ctx.Providers.StorageProvider.LoadU2FDeviceHandle(userSession.Username)
|
||||
device, err := ctx.Providers.StorageProvider.LoadU2FDevice(ctx, userSession.Username)
|
||||
|
||||
if err != nil {
|
||||
if err == storage.ErrNoU2FDeviceHandle {
|
||||
|
@ -48,16 +48,16 @@ func SecondFactorU2FSignGet(ctx *middlewares.AutheliaCtx) {
|
|||
}
|
||||
|
||||
var registration u2f.Registration
|
||||
registration.KeyHandle = keyHandleBytes
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), publicKeyBytes)
|
||||
registration.KeyHandle = device.KeyHandle
|
||||
x, y := elliptic.Unmarshal(elliptic.P256(), device.PublicKey)
|
||||
registration.PubKey.Curve = elliptic.P256()
|
||||
registration.PubKey.X = x
|
||||
registration.PubKey.Y = y
|
||||
|
||||
// Save the challenge and registration for use in next request
|
||||
userSession.U2FRegistration = &session.U2FRegistration{
|
||||
KeyHandle: keyHandleBytes,
|
||||
PublicKey: publicKeyBytes,
|
||||
KeyHandle: device.KeyHandle,
|
||||
PublicKey: device.PublicKey,
|
||||
}
|
||||
userSession.U2FChallenge = challenge
|
||||
err = ctx.SaveSession(userSession)
|
||||
|
|
|
@ -3,97 +3,25 @@ package handlers
|
|||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/middlewares"
|
||||
"github.com/authelia/authelia/v4/internal/storage"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
func loadInfo(username string, storageProvider storage.Provider, userInfo *UserInfo, logger *logrus.Entry) []error {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(3)
|
||||
|
||||
errors := make([]error, 0)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
method, err := storageProvider.LoadPreferred2FAMethod(username)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
logger.Error(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if method == "" {
|
||||
userInfo.Method = authentication.PossibleMethods[0]
|
||||
} else {
|
||||
userInfo.Method = method
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
_, _, err := storageProvider.LoadU2FDeviceHandle(username)
|
||||
if err != nil {
|
||||
if err == storage.ErrNoU2FDeviceHandle {
|
||||
return
|
||||
}
|
||||
|
||||
errors = append(errors, err)
|
||||
logger.Error(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
userInfo.HasU2F = true
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
_, err := storageProvider.LoadTOTPSecret(username)
|
||||
if err != nil {
|
||||
if err == storage.ErrNoTOTPSecret {
|
||||
return
|
||||
}
|
||||
|
||||
errors = append(errors, err)
|
||||
logger.Error(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
userInfo.HasTOTP = true
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// UserInfoGet get the info related to the user identified by the session.
|
||||
func UserInfoGet(ctx *middlewares.AutheliaCtx) {
|
||||
userSession := ctx.GetSession()
|
||||
|
||||
userInfo := UserInfo{}
|
||||
errors := loadInfo(userSession.Username, ctx.Providers.StorageProvider, &userInfo, ctx.Logger)
|
||||
|
||||
if len(errors) > 0 {
|
||||
ctx.Error(fmt.Errorf("unable to load user information"), messageOperationFailed)
|
||||
userInfo, err := ctx.Providers.StorageProvider.LoadUserInfo(ctx, userSession.Username)
|
||||
if err != nil {
|
||||
ctx.Error(fmt.Errorf("unable to load user information: %v", err), messageOperationFailed)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo.DisplayName = userSession.DisplayName
|
||||
|
||||
err := ctx.SetJSONBody(userInfo)
|
||||
err = ctx.SetJSONBody(userInfo)
|
||||
if err != nil {
|
||||
ctx.Logger.Errorf("Unable to set user info response in body: %s", err)
|
||||
}
|
||||
|
@ -121,7 +49,7 @@ func MethodPreferencePost(ctx *middlewares.AutheliaCtx) {
|
|||
|
||||
userSession := ctx.GetSession()
|
||||
ctx.Logger.Debugf("Save new preferred 2FA method of user %s to %s", userSession.Username, bodyJSON.Method)
|
||||
err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(userSession.Username, bodyJSON.Method)
|
||||
err = ctx.Providers.StorageProvider.SavePreferred2FAMethod(ctx, userSession.Username, bodyJSON.Method)
|
||||
|
||||
if err != nil {
|
||||
ctx.Error(fmt.Errorf("unable to save new preferred 2FA method: %s", err), messageOperationFailed)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
|
@ -11,7 +13,7 @@ import (
|
|||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/mocks"
|
||||
"github.com/authelia/authelia/v4/internal/storage"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
type FetchSuite struct {
|
||||
|
@ -33,62 +35,59 @@ func (s *FetchSuite) TearDownTest() {
|
|||
s.mock.Close()
|
||||
}
|
||||
|
||||
func setPreferencesExpectations(preferences UserInfo, provider *storage.MockProvider) {
|
||||
provider.
|
||||
EXPECT().
|
||||
LoadPreferred2FAMethod(gomock.Eq("john")).
|
||||
Return(preferences.Method, nil)
|
||||
|
||||
if preferences.HasU2F {
|
||||
u2fData := []byte("abc")
|
||||
provider.
|
||||
EXPECT().
|
||||
LoadU2FDeviceHandle(gomock.Eq("john")).
|
||||
Return(u2fData, u2fData, nil)
|
||||
} else {
|
||||
provider.
|
||||
EXPECT().
|
||||
LoadU2FDeviceHandle(gomock.Eq("john")).
|
||||
Return(nil, nil, storage.ErrNoU2FDeviceHandle)
|
||||
}
|
||||
|
||||
if preferences.HasTOTP {
|
||||
totpSecret := "secret"
|
||||
provider.
|
||||
EXPECT().
|
||||
LoadTOTPSecret(gomock.Eq("john")).
|
||||
Return(totpSecret, nil)
|
||||
} else {
|
||||
provider.
|
||||
EXPECT().
|
||||
LoadTOTPSecret(gomock.Eq("john")).
|
||||
Return("", storage.ErrNoTOTPSecret)
|
||||
}
|
||||
type expectedResponse struct {
|
||||
db models.UserInfo
|
||||
api *models.UserInfo
|
||||
err error
|
||||
}
|
||||
|
||||
func TestMethodSetToU2F(t *testing.T) {
|
||||
table := []UserInfo{
|
||||
expectedResponses := []expectedResponse{
|
||||
{
|
||||
Method: "totp",
|
||||
db: models.UserInfo{
|
||||
Method: "totp",
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
Method: "u2f",
|
||||
HasU2F: true,
|
||||
HasTOTP: true,
|
||||
db: models.UserInfo{
|
||||
Method: "u2f",
|
||||
HasU2F: true,
|
||||
HasTOTP: true,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
Method: "u2f",
|
||||
HasU2F: true,
|
||||
HasTOTP: false,
|
||||
db: models.UserInfo{
|
||||
Method: "u2f",
|
||||
HasU2F: true,
|
||||
HasTOTP: false,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
Method: "mobile_push",
|
||||
HasU2F: false,
|
||||
HasTOTP: false,
|
||||
db: models.UserInfo{
|
||||
Method: "mobile_push",
|
||||
HasU2F: false,
|
||||
HasTOTP: false,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
db: models.UserInfo{},
|
||||
err: sql.ErrNoRows,
|
||||
},
|
||||
{
|
||||
db: models.UserInfo{},
|
||||
err: errors.New("invalid thing"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, expectedPreferences := range table {
|
||||
for _, resp := range expectedResponses {
|
||||
if resp.api == nil {
|
||||
resp.api = &resp.db
|
||||
}
|
||||
|
||||
mock := mocks.NewMockAutheliaCtx(t)
|
||||
// Set the initial user session.
|
||||
userSession := mock.Ctx.GetSession()
|
||||
|
@ -97,64 +96,57 @@ func TestMethodSetToU2F(t *testing.T) {
|
|||
err := mock.Ctx.SaveSession(userSession)
|
||||
require.NoError(t, err)
|
||||
|
||||
setPreferencesExpectations(expectedPreferences, mock.StorageProviderMock)
|
||||
mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
LoadUserInfo(mock.Ctx, gomock.Eq("john")).
|
||||
Return(resp.db, resp.err)
|
||||
|
||||
UserInfoGet(mock.Ctx)
|
||||
|
||||
actualPreferences := UserInfo{}
|
||||
mock.GetResponseData(t, &actualPreferences)
|
||||
if resp.err == nil {
|
||||
t.Run("expected status code", func(t *testing.T) {
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
})
|
||||
|
||||
t.Run("expected method", func(t *testing.T) {
|
||||
assert.Equal(t, expectedPreferences.Method, actualPreferences.Method)
|
||||
})
|
||||
actualPreferences := models.UserInfo{}
|
||||
|
||||
t.Run("registered u2f", func(t *testing.T) {
|
||||
assert.Equal(t, expectedPreferences.HasU2F, actualPreferences.HasU2F)
|
||||
})
|
||||
mock.GetResponseData(t, &actualPreferences)
|
||||
|
||||
t.Run("expected method", func(t *testing.T) {
|
||||
assert.Equal(t, resp.api.Method, actualPreferences.Method)
|
||||
})
|
||||
|
||||
t.Run("registered u2f", func(t *testing.T) {
|
||||
assert.Equal(t, resp.api.HasU2F, actualPreferences.HasU2F)
|
||||
})
|
||||
|
||||
t.Run("registered totp", func(t *testing.T) {
|
||||
assert.Equal(t, resp.api.HasTOTP, actualPreferences.HasTOTP)
|
||||
})
|
||||
} else {
|
||||
t.Run("expected status code", func(t *testing.T) {
|
||||
assert.Equal(t, 200, mock.Ctx.Response.StatusCode())
|
||||
})
|
||||
|
||||
errResponse := mock.GetResponseError(t)
|
||||
|
||||
assert.Equal(t, "KO", errResponse.Status)
|
||||
assert.Equal(t, "Operation failed.", errResponse.Message)
|
||||
}
|
||||
|
||||
t.Run("registered totp", func(t *testing.T) {
|
||||
assert.Equal(t, expectedPreferences.HasTOTP, actualPreferences.HasTOTP)
|
||||
})
|
||||
mock.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FetchSuite) TestShouldGetDefaultPreferenceIfNotInDB() {
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
LoadPreferred2FAMethod(gomock.Eq("john")).
|
||||
Return("", nil)
|
||||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
LoadU2FDeviceHandle(gomock.Eq("john")).
|
||||
Return(nil, nil, storage.ErrNoU2FDeviceHandle)
|
||||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
LoadTOTPSecret(gomock.Eq("john")).
|
||||
Return("", storage.ErrNoTOTPSecret)
|
||||
|
||||
UserInfoGet(s.mock.Ctx)
|
||||
s.mock.Assert200OK(s.T(), UserInfo{Method: "totp"})
|
||||
}
|
||||
|
||||
func (s *FetchSuite) TestShouldReturnError500WhenStorageFailsToLoad() {
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
LoadPreferred2FAMethod(gomock.Eq("john")).
|
||||
Return("", fmt.Errorf("Failure"))
|
||||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
LoadU2FDeviceHandle(gomock.Eq("john"))
|
||||
|
||||
s.mock.StorageProviderMock.
|
||||
EXPECT().
|
||||
LoadTOTPSecret(gomock.Eq("john"))
|
||||
LoadUserInfo(s.mock.Ctx, gomock.Eq("john")).
|
||||
Return(models.UserInfo{}, fmt.Errorf("failure"))
|
||||
|
||||
UserInfoGet(s.mock.Ctx)
|
||||
|
||||
s.mock.Assert200KO(s.T(), "Operation failed.")
|
||||
assert.Equal(s.T(), "unable to load user information", s.mock.Hook.LastEntry().Message)
|
||||
assert.Equal(s.T(), "unable to load user information: failure", s.mock.Hook.LastEntry().Message)
|
||||
assert.Equal(s.T(), logrus.ErrorLevel, s.mock.Hook.LastEntry().Level)
|
||||
}
|
||||
|
||||
|
@ -220,7 +212,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenBadMethodProvided() {
|
|||
func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() {
|
||||
s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}"))
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
SavePreferred2FAMethod(gomock.Eq("john"), gomock.Eq("u2f")).
|
||||
SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")).
|
||||
Return(fmt.Errorf("Failure"))
|
||||
|
||||
MethodPreferencePost(s.mock.Ctx)
|
||||
|
@ -233,7 +225,7 @@ func (s *SaveSuite) TestShouldReturnError500WhenDatabaseFailsToSave() {
|
|||
func (s *SaveSuite) TestShouldReturn200WhenMethodIsSuccessfullySaved() {
|
||||
s.mock.Ctx.Request.SetBody([]byte("{\"method\":\"u2f\"}"))
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
SavePreferred2FAMethod(gomock.Eq("john"), gomock.Eq("u2f")).
|
||||
SavePreferred2FAMethod(s.mock.Ctx, gomock.Eq("john"), gomock.Eq("u2f")).
|
||||
Return(nil)
|
||||
|
||||
MethodPreferencePost(s.mock.Ctx)
|
||||
|
|
|
@ -1,15 +1,18 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/pquerna/otp/totp"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
// TOTPVerifier is the interface for verifying TOTPs.
|
||||
type TOTPVerifier interface {
|
||||
Verify(token, secret string) (bool, error)
|
||||
Verify(config *models.TOTPConfiguration, token string) (bool, error)
|
||||
}
|
||||
|
||||
// TOTPVerifierImpl the production implementation for TOTP verification.
|
||||
|
@ -19,13 +22,43 @@ type TOTPVerifierImpl struct {
|
|||
}
|
||||
|
||||
// Verify verifies TOTPs.
|
||||
func (tv *TOTPVerifierImpl) Verify(token, secret string) (bool, error) {
|
||||
opts := totp.ValidateOpts{
|
||||
Period: tv.Period,
|
||||
Skew: tv.Skew,
|
||||
Digits: otp.DigitsSix,
|
||||
Algorithm: otp.AlgorithmSHA1,
|
||||
func (tv *TOTPVerifierImpl) Verify(config *models.TOTPConfiguration, token string) (bool, error) {
|
||||
if config == nil {
|
||||
return false, errors.New("config not provided")
|
||||
}
|
||||
|
||||
return totp.ValidateCustom(token, secret, time.Now().UTC(), opts)
|
||||
opts := totp.ValidateOpts{
|
||||
Period: uint(config.Period),
|
||||
Skew: tv.Skew,
|
||||
Digits: otp.Digits(config.Digits),
|
||||
Algorithm: otpStringToAlgo(config.Algorithm),
|
||||
}
|
||||
|
||||
return totp.ValidateCustom(token, config.Secret, time.Now().UTC(), opts)
|
||||
}
|
||||
|
||||
func otpAlgoToString(algorithm otp.Algorithm) (out string) {
|
||||
switch algorithm {
|
||||
case otp.AlgorithmSHA1:
|
||||
return totpAlgoSHA1
|
||||
case otp.AlgorithmSHA256:
|
||||
return totpAlgoSHA256
|
||||
case otp.AlgorithmSHA512:
|
||||
return totpAlgoSHA512
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func otpStringToAlgo(in string) (algorithm otp.Algorithm) {
|
||||
switch in {
|
||||
case totpAlgoSHA1:
|
||||
return otp.AlgorithmSHA1
|
||||
case totpAlgoSHA256:
|
||||
return otp.AlgorithmSHA256
|
||||
case totpAlgoSHA512:
|
||||
return otp.AlgorithmSHA512
|
||||
default:
|
||||
return otp.AlgorithmSHA1
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,45 +5,47 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
"reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
// MockTOTPVerifier is a mock of TOTPVerifier interface
|
||||
// MockTOTPVerifier is a mock of TOTPVerifier interface.
|
||||
type MockTOTPVerifier struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTOTPVerifierMockRecorder
|
||||
}
|
||||
|
||||
// MockTOTPVerifierMockRecorder is the mock recorder for MockTOTPVerifier
|
||||
// MockTOTPVerifierMockRecorder is the mock recorder for MockTOTPVerifier.
|
||||
type MockTOTPVerifierMockRecorder struct {
|
||||
mock *MockTOTPVerifier
|
||||
}
|
||||
|
||||
// NewMockTOTPVerifier creates a new mock instance
|
||||
// NewMockTOTPVerifier creates a new mock instance.
|
||||
func NewMockTOTPVerifier(ctrl *gomock.Controller) *MockTOTPVerifier {
|
||||
mock := &MockTOTPVerifier{ctrl: ctrl}
|
||||
mock.recorder = &MockTOTPVerifierMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockTOTPVerifier) EXPECT() *MockTOTPVerifierMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Verify mocks base method
|
||||
func (m *MockTOTPVerifier) Verify(token, secret string) (bool, error) {
|
||||
// Verify mocks base method.
|
||||
func (m *MockTOTPVerifier) Verify(arg0 *models.TOTPConfiguration, arg1 string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Verify", token, secret)
|
||||
ret := m.ctrl.Call(m, "Verify", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Verify indicates an expected call of Verify
|
||||
func (mr *MockTOTPVerifierMockRecorder) Verify(token, secret interface{}) *gomock.Call {
|
||||
// Verify indicates an expected call of Verify.
|
||||
func (mr *MockTOTPVerifierMockRecorder) Verify(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockTOTPVerifier)(nil).Verify), token, secret)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Verify", reflect.TypeOf((*MockTOTPVerifier)(nil).Verify), arg0, arg1)
|
||||
}
|
||||
|
|
|
@ -11,21 +11,6 @@ type MethodList = []string
|
|||
|
||||
type authorizationMatching int
|
||||
|
||||
// UserInfo is the model of user info and second factor preferences.
|
||||
type UserInfo struct {
|
||||
// The users display name.
|
||||
DisplayName string `json:"display_name"`
|
||||
|
||||
// The preferred 2FA method.
|
||||
Method string `json:"method" valid:"required"`
|
||||
|
||||
// True if a security key has been registered.
|
||||
HasU2F bool `json:"has_u2f" valid:"required"`
|
||||
|
||||
// True if a TOTP device has been registered.
|
||||
HasTOTP bool `json:"has_totp" valid:"required"`
|
||||
}
|
||||
|
||||
// signTOTPRequestBody model of the request body received by TOTP authentication endpoint.
|
||||
type signTOTPRequestBody struct {
|
||||
Token string `json:"token" valid:"required"`
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/templates"
|
||||
)
|
||||
|
||||
|
@ -47,7 +48,9 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
|
|||
return
|
||||
}
|
||||
|
||||
err = ctx.Providers.StorageProvider.SaveIdentityVerificationToken(ss)
|
||||
err = ctx.Providers.StorageProvider.SaveIdentityVerification(ctx, models.IdentityVerification{
|
||||
Token: ss,
|
||||
})
|
||||
if err != nil {
|
||||
ctx.Error(err, messageOperationFailed)
|
||||
return
|
||||
|
@ -128,7 +131,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
|
|||
return
|
||||
}
|
||||
|
||||
found, err := ctx.Providers.StorageProvider.FindIdentityVerificationToken(finishBody.Token)
|
||||
found, err := ctx.Providers.StorageProvider.FindIdentityVerification(ctx, finishBody.Token)
|
||||
|
||||
if err != nil {
|
||||
ctx.Error(err, messageOperationFailed)
|
||||
|
@ -185,7 +188,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
|
|||
}
|
||||
|
||||
// TODO(c.michaud): find a way to garbage collect unused tokens.
|
||||
err = ctx.Providers.StorageProvider.RemoveIdentityVerificationToken(finishBody.Token)
|
||||
err = ctx.Providers.StorageProvider.RemoveIdentityVerification(ctx, finishBody.Token)
|
||||
if err != nil {
|
||||
ctx.Error(err, messageOperationFailed)
|
||||
return
|
||||
|
|
|
@ -55,7 +55,7 @@ func TestShouldFailIfJWTCannotBeSaved(t *testing.T) {
|
|||
mock.Ctx.Configuration.JWTSecret = testJWTSecret
|
||||
|
||||
mock.StorageProviderMock.EXPECT().
|
||||
SaveIdentityVerificationToken(gomock.Any()).
|
||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
||||
Return(fmt.Errorf("cannot save"))
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
|
@ -74,7 +74,7 @@ func TestShouldFailSendingAnEmail(t *testing.T) {
|
|||
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
||||
|
||||
mock.StorageProviderMock.EXPECT().
|
||||
SaveIdentityVerificationToken(gomock.Any()).
|
||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
mock.NotifierMock.EXPECT().
|
||||
|
@ -96,7 +96,7 @@ func TestShouldFailWhenXForwardedProtoHeaderIsMissing(t *testing.T) {
|
|||
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
||||
|
||||
mock.StorageProviderMock.EXPECT().
|
||||
SaveIdentityVerificationToken(gomock.Any()).
|
||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
|
@ -114,7 +114,7 @@ func TestShouldFailWhenXForwardedHostHeaderIsMissing(t *testing.T) {
|
|||
mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
||||
|
||||
mock.StorageProviderMock.EXPECT().
|
||||
SaveIdentityVerificationToken(gomock.Any()).
|
||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
args := newArgs(defaultRetriever)
|
||||
|
@ -132,7 +132,7 @@ func TestShouldSucceedIdentityVerificationStartProcess(t *testing.T) {
|
|||
mock.Ctx.Request.Header.Add("X-Forwarded-Host", "host")
|
||||
|
||||
mock.StorageProviderMock.EXPECT().
|
||||
SaveIdentityVerificationToken(gomock.Any()).
|
||||
SaveIdentityVerification(mock.Ctx, gomock.Any()).
|
||||
Return(nil)
|
||||
|
||||
mock.NotifierMock.EXPECT().
|
||||
|
@ -209,7 +209,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB(
|
|||
s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}")
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq("abc")).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq("abc")).
|
||||
Return(false, nil)
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
@ -222,7 +222,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() {
|
|||
s.mock.Ctx.Request.SetBodyString("{\"token\":\"abc\"}")
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq("abc")).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq("abc")).
|
||||
Return(true, nil)
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
@ -238,7 +238,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
|
|||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq(token)).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(true, nil)
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
@ -253,7 +253,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
|
|||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq(token)).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(true, nil)
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
@ -268,7 +268,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
|
|||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq(token)).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(true, nil)
|
||||
|
||||
args := newFinishArgs()
|
||||
|
@ -285,11 +285,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
|
|||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq(token)).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(true, nil)
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
RemoveIdentityVerificationToken(gomock.Eq(token)).
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(fmt.Errorf("cannot remove"))
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
@ -304,11 +304,11 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete(
|
|||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
FindIdentityVerificationToken(gomock.Eq(token)).
|
||||
FindIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(true, nil)
|
||||
|
||||
s.mock.StorageProviderMock.EXPECT().
|
||||
RemoveIdentityVerificationToken(gomock.Eq(token)).
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(token)).
|
||||
Return(nil)
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
|
|
@ -28,11 +28,6 @@ type AutheliaCtx struct {
|
|||
Clock utils.Clock
|
||||
}
|
||||
|
||||
// ProviderWithStartupCheck represents a provider that has a startup check.
|
||||
type ProviderWithStartupCheck interface {
|
||||
StartupCheck(logger *logrus.Logger) (err error)
|
||||
}
|
||||
|
||||
// Providers contain all provider provided to Authelia.
|
||||
type Providers struct {
|
||||
Authorizer *authorization.Authorizer
|
||||
|
|
|
@ -183,3 +183,11 @@ func (m *MockAutheliaCtx) GetResponseData(t *testing.T, data interface{}) {
|
|||
err := json.Unmarshal(m.Ctx.Response.Body(), &okResponse)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// GetResponseError retrieves an error response from the service.
|
||||
func (m *MockAutheliaCtx) GetResponseError(t *testing.T) (errResponse middlewares.ErrorResponse) {
|
||||
err := json.Unmarshal(m.Ctx.Response.Body(), &errResponse)
|
||||
require.NoError(t, err)
|
||||
|
||||
return errResponse
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// MockNotifier is a mock of Notifier interface.
|
||||
|
@ -49,15 +48,15 @@ func (mr *MockNotifierMockRecorder) Send(arg0, arg1, arg2, arg3 interface{}) *go
|
|||
}
|
||||
|
||||
// StartupCheck mocks base method.
|
||||
func (m *MockNotifier) StartupCheck(arg0 *logrus.Logger) error {
|
||||
func (m *MockNotifier) StartupCheck() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartupCheck", arg0)
|
||||
ret := m.ctrl.Call(m, "StartupCheck")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StartupCheck indicates an expected call of StartupCheck.
|
||||
func (mr *MockNotifierMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
|
||||
func (mr *MockNotifierMockRecorder) StartupCheck() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockNotifier)(nil).StartupCheck))
|
||||
}
|
||||
|
|
|
@ -5,12 +5,11 @@
|
|||
package mocks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
reflect "reflect"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/sirupsen/logrus"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
authentication "github.com/authelia/authelia/v4/internal/authentication"
|
||||
)
|
||||
|
||||
// MockUserProvider is a mock of UserProvider interface.
|
||||
|
@ -66,7 +65,21 @@ func (mr *MockUserProviderMockRecorder) GetDetails(arg0 interface{}) *gomock.Cal
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDetails", reflect.TypeOf((*MockUserProvider)(nil).GetDetails), arg0)
|
||||
}
|
||||
|
||||
// UpdatePassword mocks base method
|
||||
// StartupCheck mocks base method.
|
||||
func (m *MockUserProvider) StartupCheck() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartupCheck")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StartupCheck indicates an expected call of StartupCheck.
|
||||
func (mr *MockUserProviderMockRecorder) StartupCheck() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck))
|
||||
}
|
||||
|
||||
// UpdatePassword mocks base method.
|
||||
func (m *MockUserProvider) UpdatePassword(arg0, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdatePassword", arg0, arg1)
|
||||
|
@ -79,17 +92,3 @@ func (mr *MockUserProviderMockRecorder) UpdatePassword(arg0, arg1 interface{}) *
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdatePassword", reflect.TypeOf((*MockUserProvider)(nil).UpdatePassword), arg0, arg1)
|
||||
}
|
||||
|
||||
// StartupCheck mocks base method.
|
||||
func (m *MockUserProvider) StartupCheck(arg0 *logrus.Logger) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartupCheck", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StartupCheck indicates an expected call of StartupCheck.
|
||||
func (mr *MockUserProviderMockRecorder) StartupCheck(arg0 *logrus.Logger) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockUserProvider)(nil).StartupCheck), arg0)
|
||||
}
|
||||
|
|
17
internal/models/model_authentication_attempt.go
Normal file
17
internal/models/model_authentication_attempt.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuthenticationAttempt represents an authentication attempt row in the database.
|
||||
type AuthenticationAttempt struct {
|
||||
ID int `db:"id"`
|
||||
Time time.Time `db:"time"`
|
||||
Successful bool `db:"successful"`
|
||||
Username string `db:"username"`
|
||||
Type string `db:"auth_type"`
|
||||
RemoteIP IPAddress `db:"remote_ip"`
|
||||
RequestURI string `db:"request_uri"`
|
||||
RequestMethod string `db:"request_method"`
|
||||
}
|
12
internal/models/model_identity_verification.go
Normal file
12
internal/models/model_identity_verification.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// IdentityVerification represents an identity verification row in the database.
|
||||
type IdentityVerification struct {
|
||||
ID int `db:"id"`
|
||||
Created time.Time `db:"created"`
|
||||
Token string `db:"token"`
|
||||
}
|
14
internal/models/model_migration.go
Normal file
14
internal/models/model_migration.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Migration represents a migration row in the database.
|
||||
type Migration struct {
|
||||
ID int `db:"id"`
|
||||
Applied time.Time `db:"applied"`
|
||||
Before int `db:"version_before"`
|
||||
After int `db:"version_after"`
|
||||
Version string `db:"application_version"`
|
||||
}
|
11
internal/models/model_totp_configuration.go
Normal file
11
internal/models/model_totp_configuration.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package models
|
||||
|
||||
// TOTPConfiguration represents a users TOTP configuration row in the database.
|
||||
type TOTPConfiguration struct {
|
||||
ID int `db:"id"`
|
||||
Username string `db:"username"`
|
||||
Algorithm string `db:"algorithm"`
|
||||
Digits int `db:"digits"`
|
||||
Period uint64 `db:"totp_period"`
|
||||
Secret string `db:"secret"`
|
||||
}
|
10
internal/models/model_u2f_device.go
Normal file
10
internal/models/model_u2f_device.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package models
|
||||
|
||||
// U2FDevice represents a users U2F device row in the database.
|
||||
type U2FDevice struct {
|
||||
ID int `db:"id"`
|
||||
Username string `db:"username"`
|
||||
Description string `db:"description"`
|
||||
KeyHandle []byte `db:"key_handle"`
|
||||
PublicKey []byte `db:"public_key"`
|
||||
}
|
16
internal/models/model_userinfo.go
Normal file
16
internal/models/model_userinfo.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package models
|
||||
|
||||
// UserInfo represents the user information required by the web UI.
|
||||
type UserInfo struct {
|
||||
// The users display name.
|
||||
DisplayName string `db:"-" json:"display_name"`
|
||||
|
||||
// The preferred 2FA method.
|
||||
Method string `db:"second_factor_method" json:"method" valid:"required"`
|
||||
|
||||
// True if a security key has been registered.
|
||||
HasU2F bool `db:"has_u2f" json:"has_u2f" valid:"required"`
|
||||
|
||||
// True if a TOTP device has been registered.
|
||||
HasTOTP bool `db:"has_totp" json:"has_totp" valid:"required"`
|
||||
}
|
42
internal/models/type_ipaddress.go
Normal file
42
internal/models/type_ipaddress.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// IPAddress is a type specific for storage of a net.IP in the database.
|
||||
type IPAddress struct {
|
||||
*net.IP
|
||||
}
|
||||
|
||||
// Value is the IPAddress implementation of the databases/sql driver.Valuer.
|
||||
func (ip IPAddress) Value() (value driver.Value, err error) {
|
||||
if ip.IP == nil {
|
||||
return driver.Value(nil), nil
|
||||
}
|
||||
|
||||
return driver.Value(ip.IP.String()), nil
|
||||
}
|
||||
|
||||
// Scan is the IPAddress implementation of the sql.Scanner.
|
||||
func (ip *IPAddress) Scan(src interface{}) (err error) {
|
||||
if src == nil {
|
||||
ip.IP = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var value string
|
||||
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
value = v
|
||||
default:
|
||||
return fmt.Errorf("invalid type %T for IPAddress %v", src, src)
|
||||
}
|
||||
|
||||
*ip.IP = net.ParseIP(value)
|
||||
|
||||
return nil
|
||||
}
|
6
internal/models/type_startup_check.go
Normal file
6
internal/models/type_startup_check.go
Normal file
|
@ -0,0 +1,6 @@
|
|||
package models
|
||||
|
||||
// StartupCheck represents a provider that has a startup check.
|
||||
type StartupCheck interface {
|
||||
StartupCheck() (err error)
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
package models
|
||||
|
||||
import "time"
|
||||
|
||||
// AuthenticationAttempt represent an authentication attempt.
|
||||
type AuthenticationAttempt struct {
|
||||
// The user who tried to authenticate.
|
||||
Username string
|
||||
// Successful true if the attempt was successful.
|
||||
Successful bool
|
||||
// The time of the attempt.
|
||||
Time time.Time
|
||||
}
|
|
@ -7,8 +7,6 @@ import (
|
|||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
|
@ -25,7 +23,7 @@ func NewFileNotifier(configuration schema.FileSystemNotifierConfiguration) *File
|
|||
}
|
||||
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (n *FileNotifier) StartupCheck(_ *logrus.Logger) (err error) {
|
||||
func (n *FileNotifier) StartupCheck() (err error) {
|
||||
dir := filepath.Dir(n.path)
|
||||
if _, err := os.Stat(dir); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
package notification
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
// Notifier interface for sending the identity verification link.
|
||||
type Notifier interface {
|
||||
models.StartupCheck
|
||||
|
||||
Send(recipient, subject, body, htmlBody string) (err error)
|
||||
StartupCheck(logger *logrus.Logger) (err error)
|
||||
}
|
||||
|
|
|
@ -10,8 +10,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
|
@ -223,7 +221,7 @@ func (n *SMTPNotifier) cleanup() {
|
|||
}
|
||||
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (n *SMTPNotifier) StartupCheck(_ *logrus.Logger) (err error) {
|
||||
func (n *SMTPNotifier) StartupCheck() (err error) {
|
||||
if err := n.dial(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,22 +6,24 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// NewProvider instantiate a ntp provider given a configuration.
|
||||
func NewProvider(config *schema.NTPConfiguration) *Provider {
|
||||
return &Provider{config}
|
||||
return &Provider{
|
||||
config: config,
|
||||
log: logging.Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// StartupCheck implements the startup check provider interface.
|
||||
func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
|
||||
func (p *Provider) StartupCheck() (err error) {
|
||||
conn, err := net.Dial("udp", p.config.Address)
|
||||
if err != nil {
|
||||
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
|
||||
p.log.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -29,7 +31,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
|
|||
defer conn.Close()
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
logger.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
|
||||
p.log.Warnf("Could not connect to NTP server to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -42,7 +44,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
|
|||
req := &ntpPacket{LeapVersionMode: ntpLeapVersionClientMode(false, version)}
|
||||
|
||||
if err := binary.Write(conn, binary.BigEndian, req); err != nil {
|
||||
logger.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err)
|
||||
p.log.Warnf("Could not write to the NTP server socket to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -52,7 +54,7 @@ func (p *Provider) StartupCheck(logger *logrus.Logger) (err error) {
|
|||
resp := &ntpPacket{}
|
||||
|
||||
if err := binary.Read(conn, binary.BigEndian, resp); err != nil {
|
||||
logger.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err)
|
||||
p.log.Warnf("Could not read from the NTP server socket to validate the system time is properly synchronized: %+v", err)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
"github.com/authelia/authelia/v4/internal/configuration/validator"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
)
|
||||
|
||||
func TestShouldCheckNTP(t *testing.T) {
|
||||
|
@ -22,5 +21,5 @@ func TestShouldCheckNTP(t *testing.T) {
|
|||
|
||||
ntp := NewProvider(&config)
|
||||
|
||||
assert.NoError(t, ntp.StartupCheck(logging.Logger()))
|
||||
assert.NoError(t, ntp.StartupCheck())
|
||||
}
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
package ntp
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
// Provider type is the NTP provider.
|
||||
type Provider struct {
|
||||
config *schema.NTPConfiguration
|
||||
log *logrus.Logger
|
||||
}
|
||||
|
||||
type ntpVersion int
|
||||
|
|
|
@ -20,10 +20,7 @@ func NewOpenIDConnectProvider(configuration *schema.OpenIDConnectConfiguration)
|
|||
return provider, nil
|
||||
}
|
||||
|
||||
provider.Store, err = NewOpenIDConnectStore(configuration)
|
||||
if err != nil {
|
||||
return provider, err
|
||||
}
|
||||
provider.Store = NewOpenIDConnectStore(configuration)
|
||||
|
||||
composeConfiguration := &compose.Config{
|
||||
AccessTokenLifespan: configuration.AccessTokenLifespan,
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
)
|
||||
|
||||
// NewOpenIDConnectStore returns a new OpenIDConnectStore using the provided schema.OpenIDConnectConfiguration.
|
||||
func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore, err error) {
|
||||
func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (store *OpenIDConnectStore) {
|
||||
logger := logging.Logger()
|
||||
|
||||
store = &OpenIDConnectStore{
|
||||
|
@ -39,7 +39,7 @@ func NewOpenIDConnectStore(configuration *schema.OpenIDConnectConfiguration) (st
|
|||
store.clients[client.ID] = NewClient(client)
|
||||
}
|
||||
|
||||
return store, nil
|
||||
return store
|
||||
}
|
||||
|
||||
// GetClientPolicy retrieves the policy from the client with the matching provided id.
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
|
||||
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
IssuerPrivateKey: exampleIssuerPrivateKey,
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||
{
|
||||
|
@ -32,8 +32,6 @@ func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
policyOne := s.GetClientPolicy("myclient")
|
||||
assert.Equal(t, authorization.OneFactor, policyOne)
|
||||
|
||||
|
@ -45,7 +43,7 @@ func TestOpenIDConnectStore_GetClientPolicy(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestOpenIDConnectStore_GetInternalClient(t *testing.T) {
|
||||
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
IssuerPrivateKey: exampleIssuerPrivateKey,
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||
{
|
||||
|
@ -58,8 +56,6 @@ func TestOpenIDConnectStore_GetInternalClient(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := s.GetClient(context.Background(), "myinvalidclient")
|
||||
assert.EqualError(t, err, "not_found")
|
||||
assert.Nil(t, client)
|
||||
|
@ -78,13 +74,12 @@ func TestOpenIDConnectStore_GetInternalClient_ValidClient(t *testing.T) {
|
|||
Scopes: []string{"openid", "profile"},
|
||||
Secret: "mysecret",
|
||||
}
|
||||
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
|
||||
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
IssuerPrivateKey: exampleIssuerPrivateKey,
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{c1},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := s.GetInternalClient(c1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, client)
|
||||
|
@ -106,20 +101,19 @@ func TestOpenIDConnectStore_GetInternalClient_InvalidClient(t *testing.T) {
|
|||
Scopes: []string{"openid", "profile"},
|
||||
Secret: "mysecret",
|
||||
}
|
||||
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
|
||||
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
IssuerPrivateKey: exampleIssuerPrivateKey,
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{c1},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
client, err := s.GetInternalClient("another-client")
|
||||
assert.Nil(t, client)
|
||||
assert.EqualError(t, err, "not_found")
|
||||
}
|
||||
|
||||
func TestOpenIDConnectStore_IsValidClientID(t *testing.T) {
|
||||
s, err := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
s := NewOpenIDConnectStore(&schema.OpenIDConnectConfiguration{
|
||||
IssuerPrivateKey: exampleIssuerPrivateKey,
|
||||
Clients: []schema.OpenIDConnectClientConfiguration{
|
||||
{
|
||||
|
@ -132,8 +126,6 @@ func TestOpenIDConnectStore_IsValidClientID(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
validClient := s.IsValidClientID("myclient")
|
||||
invalidClient := s.IsValidClientID("myinvalidclient")
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package regulation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -11,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
// NewRegulator create a regulator instance.
|
||||
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.Provider, clock utils.Clock) *Regulator {
|
||||
func NewRegulator(configuration *schema.RegulationConfiguration, provider storage.RegulatorProvider, clock utils.Clock) *Regulator {
|
||||
regulator := &Regulator{storageProvider: provider}
|
||||
regulator.clock = clock
|
||||
|
||||
|
@ -40,30 +41,25 @@ func NewRegulator(configuration *schema.RegulationConfiguration, provider storag
|
|||
return regulator
|
||||
}
|
||||
|
||||
// Mark mark an authentication attempt.
|
||||
// Mark an authentication attempt.
|
||||
// We split Mark and Regulate in order to avoid timing attacks.
|
||||
func (r *Regulator) Mark(username string, successful bool) error {
|
||||
return r.storageProvider.AppendAuthenticationLog(models.AuthenticationAttempt{
|
||||
func (r *Regulator) Mark(ctx context.Context, username string, successful bool) error {
|
||||
return r.storageProvider.AppendAuthenticationLog(ctx, models.AuthenticationAttempt{
|
||||
Username: username,
|
||||
Successful: successful,
|
||||
Time: r.clock.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// Regulate regulate the authentication attempts for a given user.
|
||||
// This method returns ErrUserIsBanned if the user is banned along with the time until when
|
||||
// the user is banned.
|
||||
func (r *Regulator) Regulate(username string) (time.Time, error) {
|
||||
// Regulate the authentication attempts for a given user.
|
||||
// This method returns ErrUserIsBanned if the user is banned along with the time until when the user is banned.
|
||||
func (r *Regulator) Regulate(ctx context.Context, username string) (time.Time, error) {
|
||||
// If there is regulation configuration, no regulation applies.
|
||||
if !r.enabled {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
now := r.clock.Now()
|
||||
|
||||
// TODO(c.michaud): make sure FindTime < BanTime.
|
||||
attempts, err := r.storageProvider.LoadLatestAuthenticationLogs(username, now.Add(-r.banTime))
|
||||
|
||||
attempts, err := r.storageProvider.LoadAuthenticationLogs(ctx, username, r.clock.Now().Add(-r.banTime), 10, 0)
|
||||
if err != nil {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package regulation_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -18,6 +19,7 @@ import (
|
|||
type RegulatorSuite struct {
|
||||
suite.Suite
|
||||
|
||||
ctx context.Context
|
||||
ctrl *gomock.Controller
|
||||
storageMock *storage.MockProvider
|
||||
configuration schema.RegulationConfiguration
|
||||
|
@ -27,6 +29,7 @@ type RegulatorSuite struct {
|
|||
func (s *RegulatorSuite) SetupTest() {
|
||||
s.ctrl = gomock.NewController(s.T())
|
||||
s.storageMock = storage.NewMockProvider(s.ctrl)
|
||||
s.ctx = context.Background()
|
||||
|
||||
s.configuration = schema.RegulationConfiguration{
|
||||
MaxRetries: 3,
|
||||
|
@ -50,12 +53,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenUserIsLegitimate() {
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
||||
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -81,12 +84,12 @@ func (s *RegulatorSuite) TestShouldNotThrowWhenFailedAuthenticationNotInFindTime
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
||||
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -117,12 +120,12 @@ func (s *RegulatorSuite) TestShouldBanUserIfLatestAttemptsAreWithinFinTime() {
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
||||
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||
}
|
||||
|
||||
|
@ -150,12 +153,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsStillBanned() {
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
||||
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||
}
|
||||
|
||||
|
@ -174,12 +177,12 @@ func (s *RegulatorSuite) TestShouldCheckUserIsNotYetBanned() {
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
||||
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -206,12 +209,12 @@ func (s *RegulatorSuite) TestShouldCheckUserWasAboutToBeBanned() {
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
||||
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -242,12 +245,12 @@ func (s *RegulatorSuite) TestShouldCheckRegulationHasBeenResetOnSuccessfulAttemp
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
regulator := regulation.NewRegulator(&s.configuration, s.storageMock, &s.clock)
|
||||
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
}
|
||||
|
||||
|
@ -277,7 +280,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
|||
}
|
||||
|
||||
s.storageMock.EXPECT().
|
||||
LoadLatestAuthenticationLogs(gomock.Eq("john"), gomock.Any()).
|
||||
LoadAuthenticationLogs(s.ctx, gomock.Eq("john"), gomock.Any(), gomock.Eq(10), gomock.Eq(0)).
|
||||
Return(attemptsInDB, nil)
|
||||
|
||||
// Check Disabled Functionality
|
||||
|
@ -288,7 +291,7 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
|||
}
|
||||
|
||||
regulator := regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
|
||||
_, err := regulator.Regulate("john")
|
||||
_, err := regulator.Regulate(s.ctx, "john")
|
||||
assert.NoError(s.T(), err)
|
||||
|
||||
// Check Enabled Functionality
|
||||
|
@ -299,6 +302,6 @@ func (s *RegulatorSuite) TestShouldHaveRegulatorDisabled() {
|
|||
}
|
||||
|
||||
regulator = regulation.NewRegulator(&configuration, s.storageMock, &s.clock)
|
||||
_, err = regulator.Regulate("john")
|
||||
_, err = regulator.Regulate(s.ctx, "john")
|
||||
assert.Equal(s.T(), regulation.ErrUserIsBanned, err)
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ type Regulator struct {
|
|||
// If a user has been banned, this duration is the timelapse during which the user is banned.
|
||||
banTime time.Duration
|
||||
|
||||
storageProvider storage.Provider
|
||||
storageProvider storage.RegulatorProvider
|
||||
|
||||
clock utils.Clock
|
||||
}
|
||||
|
|
|
@ -1,39 +1,52 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
const storageSchemaCurrentVersion = SchemaVersion(1)
|
||||
const storageSchemaUpgradeMessage = "Storage schema upgraded to v"
|
||||
const storageSchemaUpgradeErrorText = "storage schema upgrade failed at v"
|
||||
const (
|
||||
tableUserPreferences = "user_preferences"
|
||||
tableIdentityVerification = "identity_verification_tokens"
|
||||
tableTOTPConfigurations = "totp_configurations"
|
||||
tableU2FDevices = "u2f_devices"
|
||||
tableDUODevices = "duo_devices"
|
||||
tableAuthenticationLogs = "authentication_logs"
|
||||
tableMigrations = "migrations"
|
||||
|
||||
// Keep table names in lower case because some DB does not support upper case.
|
||||
const userPreferencesTableName = "user_preferences"
|
||||
const identityVerificationTokensTableName = "identity_verification_tokens"
|
||||
const totpSecretsTableName = "totp_secrets"
|
||||
const u2fDeviceHandlesTableName = "u2f_devices"
|
||||
const authenticationLogsTableName = "authentication_logs"
|
||||
const configTableName = "config"
|
||||
tablePrefixBackup = "_bkp_"
|
||||
)
|
||||
|
||||
// sqlUpgradeCreateTableStatements is a map of the schema version number, plus a map of the table name and the statement used to create it.
|
||||
// The statement is fmt.Sprintf'd with the table name as the first argument.
|
||||
var sqlUpgradeCreateTableStatements = map[SchemaVersion]map[string]string{
|
||||
SchemaVersion(1): {
|
||||
userPreferencesTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, second_factor_method VARCHAR(11))",
|
||||
identityVerificationTokensTableName: "CREATE TABLE %s (token VARCHAR(512))",
|
||||
totpSecretsTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, secret VARCHAR(64))",
|
||||
u2fDeviceHandlesTableName: "CREATE TABLE %s (username VARCHAR(100) PRIMARY KEY, keyHandle TEXT, publicKey TEXT)",
|
||||
authenticationLogsTableName: "CREATE TABLE %s (username VARCHAR(100), successful BOOL, time INTEGER)",
|
||||
configTableName: "CREATE TABLE %s (category VARCHAR(32) NOT NULL, key_name VARCHAR(32) NOT NULL, value TEXT, PRIMARY KEY (category, key_name))",
|
||||
},
|
||||
}
|
||||
// WARNING: Do not change/remove these consts. They are used for Pre1 migrations.
|
||||
const (
|
||||
tablePre1TOTPSecrets = "totp_secrets"
|
||||
tablePre1Config = "config"
|
||||
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
|
||||
tableAlphaAuthenticationLogs = "AuthenticationLogs"
|
||||
tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens"
|
||||
tableAlphaPreferences = "Preferences"
|
||||
tableAlphaPreferencesTableName = "PreferencesTableName"
|
||||
tableAlphaSecondFactorPreferences = "SecondFactorPreferences"
|
||||
tableAlphaTOTPSecrets = "TOTPSecrets"
|
||||
tableAlphaU2FDeviceHandles = "U2FDeviceHandles"
|
||||
)
|
||||
|
||||
// sqlUpgradesCreateTableIndexesStatements is a map of t he schema version number, plus a slice of statements to create all of the indexes.
|
||||
var sqlUpgradesCreateTableIndexesStatements = map[SchemaVersion][]string{
|
||||
SchemaVersion(1): {
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s (username, time)", authenticationLogsTableName),
|
||||
},
|
||||
}
|
||||
const (
|
||||
providerAll = "all"
|
||||
providerMySQL = "mysql"
|
||||
providerPostgres = "postgres"
|
||||
providerSQLite = "sqlite"
|
||||
)
|
||||
|
||||
const unitTestUser = "john"
|
||||
const (
|
||||
// This is the latest schema version for the purpose of tests.
|
||||
testLatestVersion = 1
|
||||
)
|
||||
|
||||
const (
|
||||
// SchemaLatest represents the value expected for a "migrate to latest" migration. It's the maximum 32bit signed integer.
|
||||
SchemaLatest = 2147483647
|
||||
)
|
||||
|
||||
var (
|
||||
reMigration = regexp.MustCompile(`^V(\d{4})\.([^.]+)\.(all|sqlite|postgres|mysql)\.(up|down)\.sql$`)
|
||||
)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package storage
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoU2FDeviceHandle error thrown when no U2F device handle has been found in DB.
|
||||
|
@ -8,4 +10,35 @@ var (
|
|||
|
||||
// ErrNoTOTPSecret error thrown when no TOTP secret has been found in DB.
|
||||
ErrNoTOTPSecret = errors.New("no TOTP secret registered")
|
||||
|
||||
// ErrNoAvailableMigrations is returned when no available migrations can be found.
|
||||
ErrNoAvailableMigrations = errors.New("no available migrations")
|
||||
|
||||
// ErrSchemaAlreadyUpToDate is returned when the schema is already up to date.
|
||||
ErrSchemaAlreadyUpToDate = errors.New("schema already up to date")
|
||||
|
||||
// ErrNoMigrationsFound is returned when no migrations were found.
|
||||
ErrNoMigrationsFound = errors.New("no schema migrations found")
|
||||
)
|
||||
|
||||
// Error formats for the storage provider.
|
||||
const (
|
||||
ErrFmtMigrateUpTargetLessThanCurrent = "schema up migration target version %d is less then the current version %d"
|
||||
ErrFmtMigrateUpTargetGreaterThanLatest = "schema up migration target version %d is greater then the latest version %d which indicates it doesn't exist"
|
||||
ErrFmtMigrateDownTargetGreaterThanCurrent = "schema down migration target version %d is greater than the current version %d"
|
||||
ErrFmtMigrateDownTargetLessThanMinimum = "schema down migration target version %d is less than the minimum version"
|
||||
ErrFmtMigrateAlreadyOnTargetVersion = "schema migration target version %d is the same current version %d"
|
||||
)
|
||||
|
||||
const (
|
||||
errFmtFailedMigration = "schema migration %d (%s) failed: %w"
|
||||
errFmtFailedMigrationPre1 = "schema migration pre1 failed: %w"
|
||||
errFmtSchemaCurrentGreaterThanLatestKnown = "current schema version is greater than the latest known schema " +
|
||||
"version, you must downgrade to schema version %d before you can use this version of Authelia"
|
||||
)
|
||||
|
||||
const (
|
||||
logFmtMigrationFromTo = "Storage schema migration from %s to %s is being attempted"
|
||||
logFmtMigrationComplete = "Storage schema migration from %s to %s is complete"
|
||||
logFmtErrClosingConn = "Error occurred closing SQL connection: %v"
|
||||
)
|
||||
|
|
204
internal/storage/migrations.go
Normal file
204
internal/storage/migrations.go
Normal file
|
@ -0,0 +1,204 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//go:embed migrations/*
|
||||
var migrationsFS embed.FS
|
||||
|
||||
func latestMigrationVersion(providerName string) (version int, err error) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := scanMigration(entry.Name())
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if m.Provider != providerName {
|
||||
continue
|
||||
}
|
||||
|
||||
if !m.Up {
|
||||
continue
|
||||
}
|
||||
|
||||
if m.Version > version {
|
||||
version = m.Version
|
||||
}
|
||||
}
|
||||
|
||||
return version, nil
|
||||
}
|
||||
|
||||
func loadMigration(providerName string, version int, up bool) (migration *SchemaMigration, err error) {
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := scanMigration(entry.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migration = &m
|
||||
|
||||
if up != migration.Up {
|
||||
continue
|
||||
}
|
||||
|
||||
if migration.Provider != providerAll && migration.Provider != providerName {
|
||||
continue
|
||||
}
|
||||
|
||||
if version != migration.Version {
|
||||
continue
|
||||
}
|
||||
|
||||
return migration, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("migration not found")
|
||||
}
|
||||
|
||||
// loadMigrations scans the migrations fs and loads the appropriate migrations for a given providerName, prior and
|
||||
// target versions. If the target version is -1 this indicates the latest version. If the target version is 0
|
||||
// this indicates the database zero state.
|
||||
func loadMigrations(providerName string, prior, target int) (migrations []SchemaMigration, err error) {
|
||||
if prior == target && (prior != -1 || target != -1) {
|
||||
return nil, errors.New("cannot migrate to the same version as prior")
|
||||
}
|
||||
|
||||
entries, err := migrationsFS.ReadDir("migrations")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
up := prior < target
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
migration, err := scanMigration(entry.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if skipMigration(providerName, up, target, prior, &migration) {
|
||||
continue
|
||||
}
|
||||
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
|
||||
if up {
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Version < migrations[j].Version
|
||||
})
|
||||
} else {
|
||||
sort.Slice(migrations, func(i, j int) bool {
|
||||
return migrations[i].Version > migrations[j].Version
|
||||
})
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
func skipMigration(providerName string, up bool, target, prior int, migration *SchemaMigration) (skip bool) {
|
||||
if migration.Provider != providerAll && migration.Provider != providerName {
|
||||
// Skip if migration.Provider is not a match.
|
||||
return true
|
||||
}
|
||||
|
||||
if up {
|
||||
if !migration.Up {
|
||||
// Skip if we wanted an Up migration but it isn't an Up migration.
|
||||
return true
|
||||
}
|
||||
|
||||
if target != -1 && (migration.Version > target || migration.Version <= prior) {
|
||||
// Skip if the migration version is greater than the target or less than or equal to the previous version.
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if migration.Up {
|
||||
// Skip if we didn't want an Up migration but it is an Up migration.
|
||||
return true
|
||||
}
|
||||
|
||||
if migration.Version == 1 && target == -1 {
|
||||
// Skip if we're targeting pre1 and the migration version is 1 as this migration will destroy all data
|
||||
// preventing a successful migration.
|
||||
return true
|
||||
}
|
||||
|
||||
if migration.Version <= target || migration.Version > prior {
|
||||
// Skip the migration if we want to go down and the migration version is less than or equal to the target
|
||||
// or greater than the previous version.
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func scanMigration(m string) (migration SchemaMigration, err error) {
|
||||
result := reMigration.FindStringSubmatch(m)
|
||||
|
||||
if result == nil || len(result) != 5 {
|
||||
return SchemaMigration{}, errors.New("invalid migration: could not parse the format")
|
||||
}
|
||||
|
||||
migration = SchemaMigration{
|
||||
Name: strings.ReplaceAll(result[2], "_", " "),
|
||||
Provider: result[3],
|
||||
}
|
||||
|
||||
data, err := migrationsFS.ReadFile(fmt.Sprintf("migrations/%s", m))
|
||||
if err != nil {
|
||||
return SchemaMigration{}, err
|
||||
}
|
||||
|
||||
migration.Query = string(data)
|
||||
|
||||
switch result[4] {
|
||||
case "up":
|
||||
migration.Up = true
|
||||
case "down":
|
||||
migration.Up = false
|
||||
default:
|
||||
return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 4 '%s' must be up or down", result[4])
|
||||
}
|
||||
|
||||
migration.Version, _ = strconv.Atoi(result[1])
|
||||
|
||||
switch migration.Provider {
|
||||
case providerAll, providerSQLite, providerMySQL, providerPostgres:
|
||||
break
|
||||
default:
|
||||
return SchemaMigration{}, fmt.Errorf("invalid migration: value in position 3 '%s' must be all, sqlite, postgres, or mysql", result[3])
|
||||
}
|
||||
|
||||
return migration, nil
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
DROP TABLE IF EXISTS authentication_logs;
|
||||
DROP TABLE IF EXISTS identity_verification_tokens;
|
||||
DROP TABLE IF EXISTS totp_configurations;
|
||||
DROP TABLE IF EXISTS u2f_devices;
|
||||
DROP TABLE IF EXISTS user_preferences;
|
||||
DROP TABLE IF EXISTS migrations;
|
|
@ -0,0 +1,55 @@
|
|||
CREATE TABLE IF NOT EXISTS authentication_logs (
|
||||
id INTEGER AUTO_INCREMENT,
|
||||
time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
successful BOOL NOT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS identity_verification_tokens (
|
||||
id INTEGER AUTO_INCREMENT,
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
token VARCHAR(512),
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY (token)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS totp_configurations (
|
||||
id INTEGER AUTO_INCREMENT,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
|
||||
digits INTEGER NOT NULL DEFAULT 6,
|
||||
totp_period INTEGER NOT NULL DEFAULT 30,
|
||||
secret VARCHAR(64) NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY (username)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS u2f_devices (
|
||||
id INTEGER AUTO_INCREMENT,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
|
||||
key_handle BLOB NOT NULL,
|
||||
public_key BLOB NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY (username, description)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_preferences (
|
||||
id INTEGER AUTO_INCREMENT,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
second_factor_method VARCHAR(11) NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY (username)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
id INTEGER AUTO_INCREMENT,
|
||||
applied TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
version_before INTEGER NULL DEFAULT NULL,
|
||||
version_after INTEGER NOT NULL,
|
||||
application_version VARCHAR(128) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
|
@ -0,0 +1,55 @@
|
|||
CREATE TABLE IF NOT EXISTS authentication_logs (
|
||||
id SERIAL,
|
||||
time TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
successful BOOLEAN NOT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS identity_verification_tokens (
|
||||
id SERIAL,
|
||||
created TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
token VARCHAR(512),
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (token)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS totp_configurations (
|
||||
id SERIAL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
|
||||
digits INTEGER NOT NULL DEFAULT 6,
|
||||
totp_period INTEGER NOT NULL DEFAULT 30,
|
||||
secret VARCHAR(64) NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (username)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS u2f_devices (
|
||||
id SERIAL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
|
||||
key_handle BYTEA NOT NULL,
|
||||
public_key BYTEA NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (username, description)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_preferences (
|
||||
id SERIAL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
second_factor_method VARCHAR(11) NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (username)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
id SERIAL,
|
||||
applied TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
version_before INTEGER NULL DEFAULT NULL,
|
||||
version_after INTEGER NOT NULL,
|
||||
application_version VARCHAR(128) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
|
@ -0,0 +1,54 @@
|
|||
CREATE TABLE IF NOT EXISTS authentication_logs (
|
||||
id INTEGER,
|
||||
time TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
successful BOOLEAN NOT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
CREATE INDEX authentication_logs_username_idx ON authentication_logs (time, username);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS identity_verification_tokens (
|
||||
id INTEGER,
|
||||
created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
token VARCHAR(512),
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (token)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS totp_configurations (
|
||||
id INTEGER,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
algorithm VARCHAR(6) NOT NULL DEFAULT 'SHA1',
|
||||
digits INTEGER(1) NOT NULL DEFAULT 6,
|
||||
totp_period INTEGER NOT NULL DEFAULT 30,
|
||||
secret VARCHAR(64) NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (username)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS u2f_devices (
|
||||
id INTEGER,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
description VARCHAR(30) NOT NULL DEFAULT 'Primary',
|
||||
key_handle BLOB NOT NULL,
|
||||
public_key BLOB NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (username, description)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_preferences (
|
||||
id INTEGER,
|
||||
username VARCHAR(100) UNIQUE NOT NULL,
|
||||
second_factor_method VARCHAR(11) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS migrations (
|
||||
id INTEGER,
|
||||
applied TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
version_before INTEGER NULL DEFAULT NULL,
|
||||
version_after INTEGER NOT NULL,
|
||||
application_version VARCHAR(128) NOT NULL,
|
||||
PRIMARY KEY (id)
|
||||
);
|
154
internal/storage/migrations_test.go
Normal file
154
internal/storage/migrations_test.go
Normal file
|
@ -0,0 +1,154 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShouldObtainCorrectUpMigrations(t *testing.T) {
|
||||
ver, err := latestMigrationVersion(providerSQLite)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testLatestVersion, ver)
|
||||
|
||||
migrations, err := loadMigrations(providerSQLite, 0, ver)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, migrations, ver)
|
||||
|
||||
for i := 0; i < len(migrations); i++ {
|
||||
assert.Equal(t, i+1, migrations[i].Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldObtainCorrectDownMigrations(t *testing.T) {
|
||||
ver, err := latestMigrationVersion(providerSQLite)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testLatestVersion, ver)
|
||||
|
||||
migrations, err := loadMigrations(providerSQLite, ver, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, migrations, ver)
|
||||
|
||||
for i := 0; i < len(migrations); i++ {
|
||||
assert.Equal(t, ver-i, migrations[i].Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationsShouldNotBeDuplicatedPostgres(t *testing.T) {
|
||||
migrations, err := loadMigrations(providerPostgres, 0, SchemaLatest)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, 0, len(migrations))
|
||||
|
||||
previousUp := make([]int, len(migrations))
|
||||
|
||||
for i, migration := range migrations {
|
||||
assert.True(t, migration.Up)
|
||||
|
||||
if i != 0 {
|
||||
for _, v := range previousUp {
|
||||
assert.NotEqual(t, v, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
previousUp = append(previousUp, migration.Version)
|
||||
}
|
||||
|
||||
migrations, err = loadMigrations(providerPostgres, SchemaLatest, 0)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, 0, len(migrations))
|
||||
|
||||
previousDown := make([]int, len(migrations))
|
||||
|
||||
for i, migration := range migrations {
|
||||
assert.False(t, migration.Up)
|
||||
|
||||
if i != 0 {
|
||||
for _, v := range previousDown {
|
||||
assert.NotEqual(t, v, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
previousDown = append(previousDown, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationsShouldNotBeDuplicatedMySQL(t *testing.T) {
|
||||
migrations, err := loadMigrations(providerMySQL, 0, SchemaLatest)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, 0, len(migrations))
|
||||
|
||||
previousUp := make([]int, len(migrations))
|
||||
|
||||
for i, migration := range migrations {
|
||||
assert.True(t, migration.Up)
|
||||
|
||||
if i != 0 {
|
||||
for _, v := range previousUp {
|
||||
assert.NotEqual(t, v, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
previousUp = append(previousUp, migration.Version)
|
||||
}
|
||||
|
||||
migrations, err = loadMigrations(providerMySQL, SchemaLatest, 0)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, 0, len(migrations))
|
||||
|
||||
previousDown := make([]int, len(migrations))
|
||||
|
||||
for i, migration := range migrations {
|
||||
assert.False(t, migration.Up)
|
||||
|
||||
if i != 0 {
|
||||
for _, v := range previousDown {
|
||||
assert.NotEqual(t, v, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
previousDown = append(previousDown, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationsShouldNotBeDuplicatedSQLite(t *testing.T) {
|
||||
migrations, err := loadMigrations(providerSQLite, 0, SchemaLatest)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, 0, len(migrations))
|
||||
|
||||
previousUp := make([]int, len(migrations))
|
||||
|
||||
for i, migration := range migrations {
|
||||
assert.True(t, migration.Up)
|
||||
|
||||
if i != 0 {
|
||||
for _, v := range previousUp {
|
||||
assert.NotEqual(t, v, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
previousUp = append(previousUp, migration.Version)
|
||||
}
|
||||
|
||||
migrations, err = loadMigrations(providerSQLite, SchemaLatest, 0)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, 0, len(migrations))
|
||||
|
||||
previousDown := make([]int, len(migrations))
|
||||
|
||||
for i, migration := range migrations {
|
||||
assert.False(t, migration.Up)
|
||||
|
||||
if i != 0 {
|
||||
for _, v := range previousDown {
|
||||
assert.NotEqual(t, v, migration.Version)
|
||||
}
|
||||
}
|
||||
|
||||
previousDown = append(previousDown, migration.Version)
|
||||
}
|
||||
}
|
|
@ -1,85 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string.
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
// MySQLProvider is a MySQL provider.
|
||||
type MySQLProvider struct {
|
||||
SQLProvider
|
||||
}
|
||||
|
||||
// NewMySQLProvider a MySQL provider.
|
||||
func NewMySQLProvider(configuration schema.MySQLStorageConfiguration) *MySQLProvider {
|
||||
provider := MySQLProvider{
|
||||
SQLProvider{
|
||||
name: "mysql",
|
||||
|
||||
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
|
||||
|
||||
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName),
|
||||
sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName),
|
||||
|
||||
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName),
|
||||
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName),
|
||||
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName),
|
||||
|
||||
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
|
||||
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
|
||||
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName),
|
||||
|
||||
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
|
||||
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
|
||||
|
||||
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
|
||||
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),
|
||||
|
||||
sqlGetExistingTables: "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema=database()",
|
||||
|
||||
sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName),
|
||||
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName),
|
||||
},
|
||||
}
|
||||
|
||||
provider.sqlUpgradesCreateTableStatements[SchemaVersion(1)][authenticationLogsTableName] = "CREATE TABLE %s (username VARCHAR(100), successful BOOL, time INTEGER, INDEX usr_time_idx (username, time))"
|
||||
|
||||
connectionString := configuration.Username
|
||||
|
||||
if configuration.Password != "" {
|
||||
connectionString += fmt.Sprintf(":%s", configuration.Password)
|
||||
}
|
||||
|
||||
if connectionString != "" {
|
||||
connectionString += "@"
|
||||
}
|
||||
|
||||
address := configuration.Host
|
||||
if configuration.Port > 0 {
|
||||
address += fmt.Sprintf(":%d", configuration.Port)
|
||||
}
|
||||
|
||||
connectionString += fmt.Sprintf("tcp(%s)", address)
|
||||
if configuration.Database != "" {
|
||||
connectionString += fmt.Sprintf("/%s", configuration.Database)
|
||||
}
|
||||
|
||||
connectionString += "?"
|
||||
connectionString += fmt.Sprintf("timeout=%ds", int32(configuration.Timeout/time.Second))
|
||||
|
||||
db, err := sql.Open("mysql", connectionString)
|
||||
if err != nil {
|
||||
provider.log.Fatalf("Unable to connect to SQL database: %v", err)
|
||||
}
|
||||
|
||||
if err := provider.initialize(db); err != nil {
|
||||
provider.log.Fatalf("Unable to initialize SQL database: %v", err)
|
||||
}
|
||||
|
||||
return &provider
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
// PostgreSQLProvider is a PostgreSQL provider.
|
||||
type PostgreSQLProvider struct {
|
||||
SQLProvider
|
||||
}
|
||||
|
||||
// NewPostgreSQLProvider a PostgreSQL provider.
|
||||
func NewPostgreSQLProvider(configuration schema.PostgreSQLStorageConfiguration) *PostgreSQLProvider {
|
||||
provider := PostgreSQLProvider{
|
||||
SQLProvider{
|
||||
name: "postgres",
|
||||
|
||||
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
|
||||
sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements,
|
||||
|
||||
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=$1", userPreferencesTableName),
|
||||
sqlUpsertSecondFactorPreference: fmt.Sprintf("INSERT INTO %s (username, second_factor_method) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET second_factor_method=$2", userPreferencesTableName),
|
||||
|
||||
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=$1)", identityVerificationTokensTableName),
|
||||
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES ($1)", identityVerificationTokensTableName),
|
||||
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=$1", identityVerificationTokensTableName),
|
||||
|
||||
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=$1", totpSecretsTableName),
|
||||
sqlUpsertTOTPSecret: fmt.Sprintf("INSERT INTO %s (username, secret) VALUES ($1, $2) ON CONFLICT (username) DO UPDATE SET secret=$2", totpSecretsTableName),
|
||||
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=$1", totpSecretsTableName),
|
||||
|
||||
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=$1", u2fDeviceHandlesTableName),
|
||||
sqlUpsertU2FDeviceHandle: fmt.Sprintf("INSERT INTO %s (username, keyHandle, publicKey) VALUES ($1, $2, $3) ON CONFLICT (username) DO UPDATE SET keyHandle=$2, publicKey=$3", u2fDeviceHandlesTableName),
|
||||
|
||||
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES ($1, $2, $3)", authenticationLogsTableName),
|
||||
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>$1 AND username=$2 ORDER BY time DESC", authenticationLogsTableName),
|
||||
|
||||
sqlGetExistingTables: "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema='public'",
|
||||
|
||||
sqlConfigSetValue: fmt.Sprintf("INSERT INTO %s (category, key_name, value) VALUES ($1, $2, $3) ON CONFLICT (category, key_name) DO UPDATE SET value=$3", configTableName),
|
||||
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=$1 AND key_name=$2", configTableName),
|
||||
},
|
||||
}
|
||||
|
||||
args := make([]string, 0)
|
||||
if configuration.Username != "" {
|
||||
args = append(args, fmt.Sprintf("user='%s'", configuration.Username))
|
||||
}
|
||||
|
||||
if configuration.Password != "" {
|
||||
args = append(args, fmt.Sprintf("password='%s'", configuration.Password))
|
||||
}
|
||||
|
||||
if configuration.Host != "" {
|
||||
args = append(args, fmt.Sprintf("host=%s", configuration.Host))
|
||||
}
|
||||
|
||||
if configuration.Port > 0 {
|
||||
args = append(args, fmt.Sprintf("port=%d", configuration.Port))
|
||||
}
|
||||
|
||||
if configuration.Database != "" {
|
||||
args = append(args, fmt.Sprintf("dbname=%s", configuration.Database))
|
||||
}
|
||||
|
||||
if configuration.SSLMode != "" {
|
||||
args = append(args, fmt.Sprintf("sslmode=%s", configuration.SSLMode))
|
||||
}
|
||||
|
||||
args = append(args, fmt.Sprintf("connect_timeout=%d", int32(configuration.Timeout/time.Second)))
|
||||
connectionString := strings.Join(args, " ")
|
||||
|
||||
db, err := sql.Open("pgx", connectionString)
|
||||
if err != nil {
|
||||
provider.log.Fatalf("Unable to connect to SQL database: %v", err)
|
||||
}
|
||||
|
||||
if err := provider.initialize(db); err != nil {
|
||||
provider.log.Fatalf("Unable to initialize SQL database: %v", err)
|
||||
}
|
||||
|
||||
return &provider
|
||||
}
|
|
@ -1,28 +1,45 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
// Provider is an interface providing storage capabilities for
|
||||
// persisting any kind of data related to Authelia.
|
||||
// Provider is an interface providing storage capabilities for persisting any kind of data related to Authelia.
|
||||
type Provider interface {
|
||||
LoadPreferred2FAMethod(username string) (string, error)
|
||||
SavePreferred2FAMethod(username string, method string) error
|
||||
models.StartupCheck
|
||||
|
||||
FindIdentityVerificationToken(token string) (bool, error)
|
||||
SaveIdentityVerificationToken(token string) error
|
||||
RemoveIdentityVerificationToken(token string) error
|
||||
RegulatorProvider
|
||||
|
||||
SaveTOTPSecret(username string, secret string) error
|
||||
LoadTOTPSecret(username string) (string, error)
|
||||
DeleteTOTPSecret(username string) error
|
||||
SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error)
|
||||
LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error)
|
||||
LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error)
|
||||
|
||||
SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error
|
||||
LoadU2FDeviceHandle(username string) (keyHandle []byte, publicKey []byte, err error)
|
||||
SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error)
|
||||
RemoveIdentityVerification(ctx context.Context, jti string) (err error)
|
||||
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
|
||||
|
||||
AppendAuthenticationLog(attempt models.AuthenticationAttempt) error
|
||||
LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error)
|
||||
SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error)
|
||||
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
|
||||
LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error)
|
||||
|
||||
SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error)
|
||||
LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error)
|
||||
|
||||
SchemaTables(ctx context.Context) (tables []string, err error)
|
||||
SchemaVersion(ctx context.Context) (version int, err error)
|
||||
SchemaMigrate(ctx context.Context, up bool, version int) (err error)
|
||||
SchemaMigrationHistory(ctx context.Context) (migrations []models.Migration, err error)
|
||||
|
||||
SchemaLatestVersion() (version int, err error)
|
||||
SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error)
|
||||
SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error)
|
||||
}
|
||||
|
||||
// RegulatorProvider is an interface providing storage capabilities for persisting any kind of data related to the regulator.
|
||||
type RegulatorProvider interface {
|
||||
AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error)
|
||||
LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error)
|
||||
}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: internal/storage/provider.go
|
||||
// Source: ./internal/storage/provider.go
|
||||
|
||||
// Package storage is a generated GoMock package.
|
||||
package storage
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
|
@ -13,199 +13,331 @@ import (
|
|||
models "github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
// MockProvider is a mock of Provider interface
|
||||
// MockProvider is a mock of Provider interface.
|
||||
type MockProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockProviderMockRecorder
|
||||
}
|
||||
|
||||
// MockProviderMockRecorder is the mock recorder for MockProvider
|
||||
// MockProviderMockRecorder is the mock recorder for MockProvider.
|
||||
type MockProviderMockRecorder struct {
|
||||
mock *MockProvider
|
||||
}
|
||||
|
||||
// NewMockProvider creates a new mock instance
|
||||
// NewMockProvider creates a new mock instance.
|
||||
func NewMockProvider(ctrl *gomock.Controller) *MockProvider {
|
||||
mock := &MockProvider{ctrl: ctrl}
|
||||
mock.recorder = &MockProviderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockProvider) EXPECT() *MockProviderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// LoadPreferred2FAMethod mocks base method
|
||||
func (m *MockProvider) LoadPreferred2FAMethod(username string) (string, error) {
|
||||
// AppendAuthenticationLog mocks base method.
|
||||
func (m *MockProvider) AppendAuthenticationLog(arg0 context.Context, arg1 models.AuthenticationAttempt) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadPreferred2FAMethod", username)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadPreferred2FAMethod indicates an expected call of LoadPreferred2FAMethod
|
||||
func (mr *MockProviderMockRecorder) LoadPreferred2FAMethod(username interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPreferred2FAMethod), username)
|
||||
}
|
||||
|
||||
// SavePreferred2FAMethod mocks base method
|
||||
func (m *MockProvider) SavePreferred2FAMethod(username, method string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SavePreferred2FAMethod", username, method)
|
||||
ret := m.ctrl.Call(m, "AppendAuthenticationLog", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SavePreferred2FAMethod indicates an expected call of SavePreferred2FAMethod
|
||||
func (mr *MockProviderMockRecorder) SavePreferred2FAMethod(username, method interface{}) *gomock.Call {
|
||||
// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog.
|
||||
func (mr *MockProviderMockRecorder) AppendAuthenticationLog(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePreferred2FAMethod), username, method)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), arg0, arg1)
|
||||
}
|
||||
|
||||
// FindIdentityVerificationToken mocks base method
|
||||
func (m *MockProvider) FindIdentityVerificationToken(token string) (bool, error) {
|
||||
// DeleteTOTPConfiguration mocks base method.
|
||||
func (m *MockProvider) DeleteTOTPConfiguration(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FindIdentityVerificationToken", token)
|
||||
ret := m.ctrl.Call(m, "DeleteTOTPConfiguration", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteTOTPConfiguration indicates an expected call of DeleteTOTPConfiguration.
|
||||
func (mr *MockProviderMockRecorder) DeleteTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).DeleteTOTPConfiguration), arg0, arg1)
|
||||
}
|
||||
|
||||
// FindIdentityVerification mocks base method.
|
||||
func (m *MockProvider) FindIdentityVerification(arg0 context.Context, arg1 string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FindIdentityVerification", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FindIdentityVerificationToken indicates an expected call of FindIdentityVerificationToken
|
||||
func (mr *MockProviderMockRecorder) FindIdentityVerificationToken(token interface{}) *gomock.Call {
|
||||
// FindIdentityVerification indicates an expected call of FindIdentityVerification.
|
||||
func (mr *MockProviderMockRecorder) FindIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerificationToken), token)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindIdentityVerification", reflect.TypeOf((*MockProvider)(nil).FindIdentityVerification), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveIdentityVerificationToken mocks base method
|
||||
func (m *MockProvider) SaveIdentityVerificationToken(token string) error {
|
||||
// LoadAuthenticationLogs mocks base method.
|
||||
func (m *MockProvider) LoadAuthenticationLogs(arg0 context.Context, arg1 string, arg2 time.Time, arg3, arg4 int) ([]models.AuthenticationAttempt, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveIdentityVerificationToken", token)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveIdentityVerificationToken indicates an expected call of SaveIdentityVerificationToken
|
||||
func (mr *MockProviderMockRecorder) SaveIdentityVerificationToken(token interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerificationToken), token)
|
||||
}
|
||||
|
||||
// RemoveIdentityVerificationToken mocks base method
|
||||
func (m *MockProvider) RemoveIdentityVerificationToken(token string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveIdentityVerificationToken", token)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemoveIdentityVerificationToken indicates an expected call of RemoveIdentityVerificationToken
|
||||
func (mr *MockProviderMockRecorder) RemoveIdentityVerificationToken(token interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerificationToken", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerificationToken), token)
|
||||
}
|
||||
|
||||
// SaveTOTPSecret mocks base method
|
||||
func (m *MockProvider) SaveTOTPSecret(username, secret string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveTOTPSecret", username, secret)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveTOTPSecret indicates an expected call of SaveTOTPSecret
|
||||
func (mr *MockProviderMockRecorder) SaveTOTPSecret(username, secret interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPSecret", reflect.TypeOf((*MockProvider)(nil).SaveTOTPSecret), username, secret)
|
||||
}
|
||||
|
||||
// LoadTOTPSecret mocks base method
|
||||
func (m *MockProvider) LoadTOTPSecret(username string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadTOTPSecret", username)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadTOTPSecret indicates an expected call of LoadTOTPSecret
|
||||
func (mr *MockProviderMockRecorder) LoadTOTPSecret(username interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPSecret", reflect.TypeOf((*MockProvider)(nil).LoadTOTPSecret), username)
|
||||
}
|
||||
|
||||
// DeleteTOTPSecret mocks base method
|
||||
func (m *MockProvider) DeleteTOTPSecret(username string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteTOTPSecret", username)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteTOTPSecret indicates an expected call of DeleteTOTPSecret
|
||||
func (mr *MockProviderMockRecorder) DeleteTOTPSecret(username interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTOTPSecret", reflect.TypeOf((*MockProvider)(nil).DeleteTOTPSecret), username)
|
||||
}
|
||||
|
||||
// SaveU2FDeviceHandle mocks base method
|
||||
func (m *MockProvider) SaveU2FDeviceHandle(username string, keyHandle, publicKey []byte) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveU2FDeviceHandle", username, keyHandle, publicKey)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveU2FDeviceHandle indicates an expected call of SaveU2FDeviceHandle
|
||||
func (mr *MockProviderMockRecorder) SaveU2FDeviceHandle(username, keyHandle, publicKey interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).SaveU2FDeviceHandle), username, keyHandle, publicKey)
|
||||
}
|
||||
|
||||
// LoadU2FDeviceHandle mocks base method
|
||||
func (m *MockProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadU2FDeviceHandle", username)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].([]byte)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// LoadU2FDeviceHandle indicates an expected call of LoadU2FDeviceHandle
|
||||
func (mr *MockProviderMockRecorder) LoadU2FDeviceHandle(username interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDeviceHandle", reflect.TypeOf((*MockProvider)(nil).LoadU2FDeviceHandle), username)
|
||||
}
|
||||
|
||||
// AppendAuthenticationLog mocks base method
|
||||
func (m *MockProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AppendAuthenticationLog", attempt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AppendAuthenticationLog indicates an expected call of AppendAuthenticationLog
|
||||
func (mr *MockProviderMockRecorder) AppendAuthenticationLog(attempt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuthenticationLog", reflect.TypeOf((*MockProvider)(nil).AppendAuthenticationLog), attempt)
|
||||
}
|
||||
|
||||
// LoadLatestAuthenticationLogs mocks base method
|
||||
func (m *MockProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadLatestAuthenticationLogs", username, fromDate)
|
||||
ret := m.ctrl.Call(m, "LoadAuthenticationLogs", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].([]models.AuthenticationAttempt)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadLatestAuthenticationLogs indicates an expected call of LoadLatestAuthenticationLogs
|
||||
func (mr *MockProviderMockRecorder) LoadLatestAuthenticationLogs(username, fromDate interface{}) *gomock.Call {
|
||||
// LoadAuthenticationLogs indicates an expected call of LoadAuthenticationLogs.
|
||||
func (mr *MockProviderMockRecorder) LoadAuthenticationLogs(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadLatestAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadLatestAuthenticationLogs), username, fromDate)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadAuthenticationLogs", reflect.TypeOf((*MockProvider)(nil).LoadAuthenticationLogs), arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
||||
|
||||
// LoadPreferred2FAMethod mocks base method.
|
||||
func (m *MockProvider) LoadPreferred2FAMethod(arg0 context.Context, arg1 string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadPreferred2FAMethod", arg0, arg1)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadPreferred2FAMethod indicates an expected call of LoadPreferred2FAMethod.
|
||||
func (mr *MockProviderMockRecorder) LoadPreferred2FAMethod(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadPreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).LoadPreferred2FAMethod), arg0, arg1)
|
||||
}
|
||||
|
||||
// LoadTOTPConfiguration mocks base method.
|
||||
func (m *MockProvider) LoadTOTPConfiguration(arg0 context.Context, arg1 string) (*models.TOTPConfiguration, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadTOTPConfiguration", arg0, arg1)
|
||||
ret0, _ := ret[0].(*models.TOTPConfiguration)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadTOTPConfiguration indicates an expected call of LoadTOTPConfiguration.
|
||||
func (mr *MockProviderMockRecorder) LoadTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).LoadTOTPConfiguration), arg0, arg1)
|
||||
}
|
||||
|
||||
// LoadU2FDevice mocks base method.
|
||||
func (m *MockProvider) LoadU2FDevice(arg0 context.Context, arg1 string) (*models.U2FDevice, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadU2FDevice", arg0, arg1)
|
||||
ret0, _ := ret[0].(*models.U2FDevice)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadU2FDevice indicates an expected call of LoadU2FDevice.
|
||||
func (mr *MockProviderMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockProvider)(nil).LoadU2FDevice), arg0, arg1)
|
||||
}
|
||||
|
||||
// LoadUserInfo mocks base method.
|
||||
func (m *MockProvider) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadUserInfo", arg0, arg1)
|
||||
ret0, _ := ret[0].(models.UserInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadUserInfo indicates an expected call of LoadUserInfo.
|
||||
func (mr *MockProviderMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockProvider)(nil).LoadUserInfo), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemoveIdentityVerification mocks base method.
|
||||
func (m *MockProvider) RemoveIdentityVerification(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification.
|
||||
func (mr *MockProviderMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).RemoveIdentityVerification), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveIdentityVerification mocks base method.
|
||||
func (m *MockProvider) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveIdentityVerification", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveIdentityVerification indicates an expected call of SaveIdentityVerification.
|
||||
func (mr *MockProviderMockRecorder) SaveIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveIdentityVerification", reflect.TypeOf((*MockProvider)(nil).SaveIdentityVerification), arg0, arg1)
|
||||
}
|
||||
|
||||
// SavePreferred2FAMethod mocks base method.
|
||||
func (m *MockProvider) SavePreferred2FAMethod(arg0 context.Context, arg1, arg2 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SavePreferred2FAMethod", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SavePreferred2FAMethod indicates an expected call of SavePreferred2FAMethod.
|
||||
func (mr *MockProviderMockRecorder) SavePreferred2FAMethod(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePreferred2FAMethod", reflect.TypeOf((*MockProvider)(nil).SavePreferred2FAMethod), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SaveTOTPConfiguration mocks base method.
|
||||
func (m *MockProvider) SaveTOTPConfiguration(arg0 context.Context, arg1 models.TOTPConfiguration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveTOTPConfiguration", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveTOTPConfiguration indicates an expected call of SaveTOTPConfiguration.
|
||||
func (mr *MockProviderMockRecorder) SaveTOTPConfiguration(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTOTPConfiguration", reflect.TypeOf((*MockProvider)(nil).SaveTOTPConfiguration), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveU2FDevice mocks base method.
|
||||
func (m *MockProvider) SaveU2FDevice(arg0 context.Context, arg1 models.U2FDevice) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveU2FDevice", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveU2FDevice indicates an expected call of SaveU2FDevice.
|
||||
func (mr *MockProviderMockRecorder) SaveU2FDevice(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveU2FDevice", reflect.TypeOf((*MockProvider)(nil).SaveU2FDevice), arg0, arg1)
|
||||
}
|
||||
|
||||
// SchemaLatestVersion mocks base method.
|
||||
func (m *MockProvider) SchemaLatestVersion() (int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SchemaLatestVersion")
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SchemaLatestVersion indicates an expected call of SchemaLatestVersion.
|
||||
func (mr *MockProviderMockRecorder) SchemaLatestVersion() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaLatestVersion", reflect.TypeOf((*MockProvider)(nil).SchemaLatestVersion))
|
||||
}
|
||||
|
||||
// SchemaMigrate mocks base method.
|
||||
func (m *MockProvider) SchemaMigrate(arg0 context.Context, arg1 bool, arg2 int) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SchemaMigrate", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SchemaMigrate indicates an expected call of SchemaMigrate.
|
||||
func (mr *MockProviderMockRecorder) SchemaMigrate(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrate", reflect.TypeOf((*MockProvider)(nil).SchemaMigrate), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SchemaMigrationHistory mocks base method.
|
||||
func (m *MockProvider) SchemaMigrationHistory(arg0 context.Context) ([]models.Migration, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SchemaMigrationHistory", arg0)
|
||||
ret0, _ := ret[0].([]models.Migration)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SchemaMigrationHistory indicates an expected call of SchemaMigrationHistory.
|
||||
func (mr *MockProviderMockRecorder) SchemaMigrationHistory(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationHistory", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationHistory), arg0)
|
||||
}
|
||||
|
||||
// SchemaMigrationsDown mocks base method.
|
||||
func (m *MockProvider) SchemaMigrationsDown(arg0 context.Context, arg1 int) ([]SchemaMigration, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SchemaMigrationsDown", arg0, arg1)
|
||||
ret0, _ := ret[0].([]SchemaMigration)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SchemaMigrationsDown indicates an expected call of SchemaMigrationsDown.
|
||||
func (mr *MockProviderMockRecorder) SchemaMigrationsDown(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsDown", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsDown), arg0, arg1)
|
||||
}
|
||||
|
||||
// SchemaMigrationsUp mocks base method.
|
||||
func (m *MockProvider) SchemaMigrationsUp(arg0 context.Context, arg1 int) ([]SchemaMigration, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SchemaMigrationsUp", arg0, arg1)
|
||||
ret0, _ := ret[0].([]SchemaMigration)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SchemaMigrationsUp indicates an expected call of SchemaMigrationsUp.
|
||||
func (mr *MockProviderMockRecorder) SchemaMigrationsUp(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaMigrationsUp", reflect.TypeOf((*MockProvider)(nil).SchemaMigrationsUp), arg0, arg1)
|
||||
}
|
||||
|
||||
// SchemaTables mocks base method.
|
||||
func (m *MockProvider) SchemaTables(arg0 context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SchemaTables", arg0)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SchemaTables indicates an expected call of SchemaTables.
|
||||
func (mr *MockProviderMockRecorder) SchemaTables(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaTables", reflect.TypeOf((*MockProvider)(nil).SchemaTables), arg0)
|
||||
}
|
||||
|
||||
// SchemaVersion mocks base method.
|
||||
func (m *MockProvider) SchemaVersion(arg0 context.Context) (int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SchemaVersion", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SchemaVersion indicates an expected call of SchemaVersion.
|
||||
func (mr *MockProviderMockRecorder) SchemaVersion(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SchemaVersion", reflect.TypeOf((*MockProvider)(nil).SchemaVersion), arg0)
|
||||
}
|
||||
|
||||
// StartupCheck mocks base method.
|
||||
func (m *MockProvider) StartupCheck() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StartupCheck")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StartupCheck indicates an expected call of StartupCheck.
|
||||
func (mr *MockProviderMockRecorder) StartupCheck() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockProvider)(nil).StartupCheck))
|
||||
}
|
||||
|
|
|
@ -1,173 +1,199 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/logging"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// NewSQLProvider generates a generic SQLProvider to be used with other SQL provider NewUp's.
|
||||
func NewSQLProvider(name, driverName, dataSourceName string) (provider SQLProvider) {
|
||||
db, err := sqlx.Open(driverName, dataSourceName)
|
||||
|
||||
provider = SQLProvider{
|
||||
name: name,
|
||||
driverName: driverName,
|
||||
db: db,
|
||||
log: logging.Logger(),
|
||||
errOpen: err,
|
||||
|
||||
sqlInsertAuthenticationAttempt: fmt.Sprintf(queryFmtInsertAuthenticationLogEntry, tableAuthenticationLogs),
|
||||
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
|
||||
|
||||
sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
|
||||
sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification),
|
||||
sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
|
||||
|
||||
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
||||
sqlDeleteTOTPConfig: fmt.Sprintf(queryFmtDeleteTOTPConfiguration, tableTOTPConfigurations),
|
||||
sqlSelectTOTPConfig: fmt.Sprintf(queryFmtSelectTOTPConfiguration, tableTOTPConfigurations),
|
||||
|
||||
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
|
||||
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
|
||||
|
||||
sqlUpsertPreferred2FAMethod: fmt.Sprintf(queryFmtUpsertPreferred2FAMethod, tableUserPreferences),
|
||||
sqlSelectPreferred2FAMethod: fmt.Sprintf(queryFmtSelectPreferred2FAMethod, tableUserPreferences),
|
||||
sqlSelectUserInfo: fmt.Sprintf(queryFmtSelectUserInfo, tableTOTPConfigurations, tableU2FDevices, tableUserPreferences),
|
||||
|
||||
sqlInsertMigration: fmt.Sprintf(queryFmtInsertMigration, tableMigrations),
|
||||
sqlSelectMigrations: fmt.Sprintf(queryFmtSelectMigrations, tableMigrations),
|
||||
sqlSelectLatestMigration: fmt.Sprintf(queryFmtSelectLatestMigration, tableMigrations),
|
||||
|
||||
sqlFmtRenameTable: queryFmtRenameTable,
|
||||
}
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
// SQLProvider is a storage provider persisting data in a SQL database.
|
||||
type SQLProvider struct {
|
||||
db *sql.DB
|
||||
log *logrus.Logger
|
||||
name string
|
||||
db *sqlx.DB
|
||||
log *logrus.Logger
|
||||
name string
|
||||
driverName string
|
||||
errOpen error
|
||||
|
||||
sqlUpgradesCreateTableStatements map[SchemaVersion]map[string]string
|
||||
sqlUpgradesCreateTableIndexesStatements map[SchemaVersion][]string
|
||||
// Table: authentication_logs.
|
||||
sqlInsertAuthenticationAttempt string
|
||||
sqlSelectAuthenticationAttemptsByUsername string
|
||||
|
||||
sqlGetPreferencesByUsername string
|
||||
sqlUpsertSecondFactorPreference string
|
||||
// Table: identity_verification_tokens.
|
||||
sqlInsertIdentityVerification string
|
||||
sqlDeleteIdentityVerification string
|
||||
sqlSelectExistsIdentityVerification string
|
||||
|
||||
sqlTestIdentityVerificationTokenExistence string
|
||||
sqlInsertIdentityVerificationToken string
|
||||
sqlDeleteIdentityVerificationToken string
|
||||
// Table: totp_configurations.
|
||||
sqlUpsertTOTPConfig string
|
||||
sqlDeleteTOTPConfig string
|
||||
sqlSelectTOTPConfig string
|
||||
|
||||
sqlGetTOTPSecretByUsername string
|
||||
sqlUpsertTOTPSecret string
|
||||
sqlDeleteTOTPSecret string
|
||||
// Table: u2f_devices.
|
||||
sqlUpsertU2FDevice string
|
||||
sqlSelectU2FDevice string
|
||||
|
||||
sqlGetU2FDeviceHandleByUsername string
|
||||
sqlUpsertU2FDeviceHandle string
|
||||
// Table: user_preferences.
|
||||
sqlUpsertPreferred2FAMethod string
|
||||
sqlSelectPreferred2FAMethod string
|
||||
sqlSelectUserInfo string
|
||||
|
||||
sqlInsertAuthenticationLog string
|
||||
sqlGetLatestAuthenticationLogs string
|
||||
// Table: migrations.
|
||||
sqlInsertMigration string
|
||||
sqlSelectMigrations string
|
||||
sqlSelectLatestMigration string
|
||||
|
||||
sqlGetExistingTables string
|
||||
|
||||
sqlConfigSetValue string
|
||||
sqlConfigGetValue string
|
||||
// Utility.
|
||||
sqlSelectExistingTables string
|
||||
sqlFmtRenameTable string
|
||||
}
|
||||
|
||||
func (p *SQLProvider) initialize(db *sql.DB) error {
|
||||
p.db = db
|
||||
p.log = logging.Logger()
|
||||
|
||||
return p.upgrade()
|
||||
}
|
||||
|
||||
func (p *SQLProvider) getSchemaBasicDetails() (version SchemaVersion, tables []string, err error) {
|
||||
rows, err := p.db.Query(p.sqlGetExistingTables)
|
||||
if err != nil {
|
||||
return version, tables, err
|
||||
// StartupCheck implements the provider startup check interface.
|
||||
func (p *SQLProvider) StartupCheck() (err error) {
|
||||
if p.errOpen != nil {
|
||||
return p.errOpen
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var table string
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&table)
|
||||
if err != nil {
|
||||
return version, tables, err
|
||||
// TODO: Decide if this is needed, or if it should be configurable.
|
||||
for i := 0; i < 19; i++ {
|
||||
err = p.db.Ping()
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
tables = append(tables, table)
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
}
|
||||
|
||||
if utils.IsStringInSlice(configTableName, tables) {
|
||||
rows, err := p.db.Query(p.sqlConfigGetValue, "schema", "version")
|
||||
if err != nil {
|
||||
return version, tables, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&version)
|
||||
if err != nil {
|
||||
return version, tables, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return version, tables, nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) upgrade() error {
|
||||
p.log.Debug("Storage schema is being checked to verify it is up to date")
|
||||
|
||||
version, tables, err := p.getSchemaBasicDetails()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if version < storageSchemaCurrentVersion {
|
||||
p.log.Debugf("Storage schema is v%d, latest is v%d", version, storageSchemaCurrentVersion)
|
||||
p.log.Infof("Storage schema is being checked for updates")
|
||||
|
||||
tx, err := p.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
switch version {
|
||||
case 0:
|
||||
err := p.upgradeSchemaToVersion001(tx, tables)
|
||||
if err != nil {
|
||||
return p.handleUpgradeFailure(tx, 1, err)
|
||||
}
|
||||
err = p.SchemaMigrate(ctx, true, SchemaLatest)
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
err := tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.log.Infof("Storage schema upgrade to v%d completed", storageSchemaCurrentVersion)
|
||||
}
|
||||
} else {
|
||||
p.log.Debug("Storage schema is up to date")
|
||||
switch err {
|
||||
case ErrSchemaAlreadyUpToDate:
|
||||
p.log.Infof("Storage schema is already up to date")
|
||||
return nil
|
||||
case nil:
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) handleUpgradeFailure(tx *sql.Tx, version SchemaVersion, err error) error {
|
||||
rollbackErr := tx.Rollback()
|
||||
formattedErr := fmt.Errorf("%s%d: %v", storageSchemaUpgradeErrorText, version, err)
|
||||
|
||||
if rollbackErr != nil {
|
||||
return fmt.Errorf("rollback error occurred: %v (inner error %v)", rollbackErr, formattedErr)
|
||||
}
|
||||
|
||||
return formattedErr
|
||||
}
|
||||
|
||||
// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
|
||||
func (p *SQLProvider) LoadPreferred2FAMethod(username string) (string, error) {
|
||||
var method string
|
||||
|
||||
rows, err := p.db.Query(p.sqlGetPreferencesByUsername, username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
err = rows.Scan(&method)
|
||||
|
||||
return method, err
|
||||
}
|
||||
|
||||
// SavePreferred2FAMethod save the preferred method for 2FA to the database.
|
||||
func (p *SQLProvider) SavePreferred2FAMethod(username string, method string) error {
|
||||
_, err := p.db.Exec(p.sqlUpsertSecondFactorPreference, username, method)
|
||||
func (p *SQLProvider) SavePreferred2FAMethod(ctx context.Context, username string, method string) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, method)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// FindIdentityVerificationToken look for an identity verification token in the database.
|
||||
func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error) {
|
||||
var found bool
|
||||
// LoadPreferred2FAMethod load the preferred method for 2FA from the database.
|
||||
func (p *SQLProvider) LoadPreferred2FAMethod(ctx context.Context, username string) (method string, err error) {
|
||||
err = p.db.GetContext(ctx, &method, p.sqlSelectPreferred2FAMethod, username)
|
||||
|
||||
err := p.db.QueryRow(p.sqlTestIdentityVerificationTokenExistence, token).Scan(&found)
|
||||
switch err {
|
||||
case sql.ErrNoRows:
|
||||
return "", nil
|
||||
case nil:
|
||||
return method, err
|
||||
default:
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// LoadUserInfo loads the models.UserInfo from the database.
|
||||
func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error) {
|
||||
err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username)
|
||||
|
||||
switch {
|
||||
case err == nil:
|
||||
return info, nil
|
||||
case errors.Is(err, sql.ErrNoRows):
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpsertPreferred2FAMethod, username, authentication.PossibleMethods[0])
|
||||
if err != nil {
|
||||
return models.UserInfo{}, err
|
||||
}
|
||||
|
||||
err = p.db.GetContext(ctx, &info, p.sqlSelectUserInfo, username, username, username)
|
||||
if err != nil {
|
||||
return models.UserInfo{}, err
|
||||
}
|
||||
|
||||
return info, nil
|
||||
default:
|
||||
return models.UserInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// SaveIdentityVerification save an identity verification record to the database.
|
||||
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification, verification.Token)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveIdentityVerification remove an identity verification record from the database.
|
||||
func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, token string) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, token)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// FindIdentityVerification checks if an identity verification record is in the database and active.
|
||||
func (p *SQLProvider) FindIdentityVerification(ctx context.Context, jti string) (found bool, err error) {
|
||||
err = p.db.GetContext(ctx, &found, p.sqlSelectExistsIdentityVerification, jti)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -175,105 +201,94 @@ func (p *SQLProvider) FindIdentityVerificationToken(token string) (bool, error)
|
|||
return found, nil
|
||||
}
|
||||
|
||||
// SaveIdentityVerificationToken save an identity verification token in the database.
|
||||
func (p *SQLProvider) SaveIdentityVerificationToken(token string) error {
|
||||
_, err := p.db.Exec(p.sqlInsertIdentityVerificationToken, token)
|
||||
// SaveTOTPConfiguration save a TOTP config of a given user in the database.
|
||||
func (p *SQLProvider) SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error) {
|
||||
// TODO: Encrypt config.Secret here.
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpsertTOTPConfig,
|
||||
config.Username,
|
||||
config.Algorithm,
|
||||
config.Digits,
|
||||
config.Period,
|
||||
config.Secret,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveIdentityVerificationToken remove an identity verification token from the database.
|
||||
func (p *SQLProvider) RemoveIdentityVerificationToken(token string) error {
|
||||
_, err := p.db.Exec(p.sqlDeleteIdentityVerificationToken, token)
|
||||
// DeleteTOTPConfiguration delete a TOTP secret from the database given a username.
|
||||
func (p *SQLProvider) DeleteTOTPConfiguration(ctx context.Context, username string) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlDeleteTOTPConfig, username)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SaveTOTPSecret save a TOTP secret of a given user in the database.
|
||||
func (p *SQLProvider) SaveTOTPSecret(username string, secret string) error {
|
||||
_, err := p.db.Exec(p.sqlUpsertTOTPSecret, username, secret)
|
||||
return err
|
||||
}
|
||||
// LoadTOTPConfiguration load a TOTP secret given a username from the database.
|
||||
func (p *SQLProvider) LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error) {
|
||||
config = &models.TOTPConfiguration{}
|
||||
|
||||
// LoadTOTPSecret load a TOTP secret given a username from the database.
|
||||
func (p *SQLProvider) LoadTOTPSecret(username string) (string, error) {
|
||||
var secret string
|
||||
if err := p.db.QueryRow(p.sqlGetTOTPSecretByUsername, username).Scan(&secret); err != nil {
|
||||
err = p.db.QueryRowxContext(ctx, p.sqlSelectTOTPConfig, username).StructScan(config)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return "", ErrNoTOTPSecret
|
||||
return nil, ErrNoTOTPSecret
|
||||
}
|
||||
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return secret, nil
|
||||
// TODO: Decrypt config.Secret here.
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// DeleteTOTPSecret delete a TOTP secret from the database given a username.
|
||||
func (p *SQLProvider) DeleteTOTPSecret(username string) error {
|
||||
_, err := p.db.Exec(p.sqlDeleteTOTPSecret, username)
|
||||
return err
|
||||
}
|
||||
|
||||
// SaveU2FDeviceHandle save a registered U2F device registration blob.
|
||||
func (p *SQLProvider) SaveU2FDeviceHandle(username string, keyHandle []byte, publicKey []byte) error {
|
||||
_, err := p.db.Exec(p.sqlUpsertU2FDeviceHandle,
|
||||
username,
|
||||
base64.StdEncoding.EncodeToString(keyHandle),
|
||||
base64.StdEncoding.EncodeToString(publicKey))
|
||||
// SaveU2FDevice saves a registered U2F device.
|
||||
func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.KeyHandle, device.PublicKey)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadU2FDeviceHandle load a U2F device registration blob for a given username.
|
||||
func (p *SQLProvider) LoadU2FDeviceHandle(username string) ([]byte, []byte, error) {
|
||||
var keyHandleBase64, publicKeyBase64 string
|
||||
if err := p.db.QueryRow(p.sqlGetU2FDeviceHandleByUsername, username).Scan(&keyHandleBase64, &publicKeyBase64); err != nil {
|
||||
// LoadU2FDevice loads a U2F device registration for a given username.
|
||||
func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) {
|
||||
device = &models.U2FDevice{
|
||||
Username: username,
|
||||
}
|
||||
|
||||
err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, ErrNoU2FDeviceHandle
|
||||
return nil, ErrNoU2FDeviceHandle
|
||||
}
|
||||
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return keyHandle, publicKey, nil
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// AppendAuthenticationLog append a mark to the authentication log.
|
||||
func (p *SQLProvider) AppendAuthenticationLog(attempt models.AuthenticationAttempt) error {
|
||||
_, err := p.db.Exec(p.sqlInsertAuthenticationLog, attempt.Username, attempt.Successful, attempt.Time.Unix())
|
||||
func (p *SQLProvider) AppendAuthenticationLog(ctx context.Context, attempt models.AuthenticationAttempt) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlInsertAuthenticationAttempt, attempt.Time, attempt.Successful, attempt.Username)
|
||||
return err
|
||||
}
|
||||
|
||||
// LoadLatestAuthenticationLogs retrieve the latest marks from the authentication log.
|
||||
func (p *SQLProvider) LoadLatestAuthenticationLogs(username string, fromDate time.Time) ([]models.AuthenticationAttempt, error) {
|
||||
var t int64
|
||||
|
||||
rows, err := p.db.Query(p.sqlGetLatestAuthenticationLogs, fromDate.Unix(), username)
|
||||
|
||||
// LoadAuthenticationLogs retrieve the latest failed authentications from the authentication log.
|
||||
func (p *SQLProvider) LoadAuthenticationLogs(ctx context.Context, username string, fromDate time.Time, limit, page int) (attempts []models.AuthenticationAttempt, err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, p.sqlSelectAuthenticationAttemptsByUsername, fromDate, username, limit, limit*page)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attempts := make([]models.AuthenticationAttempt, 0, 10)
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
attempts = make([]models.AuthenticationAttempt, 0, limit)
|
||||
|
||||
for rows.Next() {
|
||||
attempt := models.AuthenticationAttempt{
|
||||
Username: username,
|
||||
}
|
||||
err = rows.Scan(&attempt.Successful, &t)
|
||||
attempt.Time = time.Unix(t, 0)
|
||||
var attempt models.AuthenticationAttempt
|
||||
|
||||
err = rows.StructScan(&attempt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
53
internal/storage/sql_provider_backend_mysql.go
Normal file
53
internal/storage/sql_provider_backend_mysql.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // Load the MySQL Driver used in the connection string.
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
// MySQLProvider is a MySQL provider.
|
||||
type MySQLProvider struct {
|
||||
SQLProvider
|
||||
}
|
||||
|
||||
// NewMySQLProvider a MySQL provider.
|
||||
func NewMySQLProvider(config schema.MySQLStorageConfiguration) (provider *MySQLProvider) {
|
||||
provider = &MySQLProvider{
|
||||
SQLProvider: NewSQLProvider(providerMySQL, providerMySQL, dataSourceNameMySQL(config)),
|
||||
}
|
||||
|
||||
// All providers have differing SELECT existing table statements.
|
||||
provider.sqlSelectExistingTables = queryMySQLSelectExistingTables
|
||||
|
||||
// Specific alterations to this provider.
|
||||
provider.sqlFmtRenameTable = queryFmtMySQLRenameTable
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
func dataSourceNameMySQL(config schema.MySQLStorageConfiguration) (dataSourceName string) {
|
||||
dataSourceName = fmt.Sprintf("%s:%s", config.Username, config.Password)
|
||||
|
||||
if dataSourceName != "" {
|
||||
dataSourceName += "@"
|
||||
}
|
||||
|
||||
address := config.Host
|
||||
if config.Port > 0 {
|
||||
address += fmt.Sprintf(":%d", config.Port)
|
||||
}
|
||||
|
||||
dataSourceName += fmt.Sprintf("tcp(%s)", address)
|
||||
if config.Database != "" {
|
||||
dataSourceName += fmt.Sprintf("/%s", config.Database)
|
||||
}
|
||||
|
||||
dataSourceName += "?"
|
||||
dataSourceName += fmt.Sprintf("timeout=%ds&multiStatements=true&parseTime=true", int32(config.Timeout/time.Second))
|
||||
|
||||
return dataSourceName
|
||||
}
|
72
internal/storage/sql_provider_backend_postgres.go
Normal file
72
internal/storage/sql_provider_backend_postgres.go
Normal file
|
@ -0,0 +1,72 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v4/stdlib" // Load the PostgreSQL Driver used in the connection string.
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/configuration/schema"
|
||||
)
|
||||
|
||||
// PostgreSQLProvider is a PostgreSQL provider.
|
||||
type PostgreSQLProvider struct {
|
||||
SQLProvider
|
||||
}
|
||||
|
||||
// NewPostgreSQLProvider a PostgreSQL provider.
|
||||
func NewPostgreSQLProvider(config schema.PostgreSQLStorageConfiguration) (provider *PostgreSQLProvider) {
|
||||
provider = &PostgreSQLProvider{
|
||||
SQLProvider: NewSQLProvider(providerPostgres, "pgx", dataSourceNamePostgreSQL(config)),
|
||||
}
|
||||
|
||||
// All providers have differing SELECT existing table statements.
|
||||
provider.sqlSelectExistingTables = queryPostgreSelectExistingTables
|
||||
|
||||
// Specific alterations to this provider.
|
||||
// PostgreSQL doesn't have a UPSERT statement but has an ON CONFLICT operation instead.
|
||||
provider.sqlUpsertU2FDevice = fmt.Sprintf(queryFmtPostgresUpsertU2FDevice, tableU2FDevices)
|
||||
provider.sqlUpsertTOTPConfig = fmt.Sprintf(queryFmtPostgresUpsertTOTPConfiguration, tableTOTPConfigurations)
|
||||
provider.sqlUpsertPreferred2FAMethod = fmt.Sprintf(queryFmtPostgresUpsertPreferred2FAMethod, tableUserPreferences)
|
||||
|
||||
// PostgreSQL requires rebinding of any query that contains a '?' placeholder to use the '$#' notation placeholders.
|
||||
provider.sqlFmtRenameTable = provider.db.Rebind(provider.sqlFmtRenameTable)
|
||||
provider.sqlSelectPreferred2FAMethod = provider.db.Rebind(provider.sqlSelectPreferred2FAMethod)
|
||||
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
|
||||
provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification)
|
||||
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
|
||||
provider.sqlDeleteIdentityVerification = provider.db.Rebind(provider.sqlDeleteIdentityVerification)
|
||||
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
|
||||
provider.sqlUpsertTOTPConfig = provider.db.Rebind(provider.sqlUpsertTOTPConfig)
|
||||
provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig)
|
||||
provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice)
|
||||
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)
|
||||
provider.sqlSelectAuthenticationAttemptsByUsername = provider.db.Rebind(provider.sqlSelectAuthenticationAttemptsByUsername)
|
||||
provider.sqlInsertMigration = provider.db.Rebind(provider.sqlInsertMigration)
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
func dataSourceNamePostgreSQL(config schema.PostgreSQLStorageConfiguration) (dataSourceName string) {
|
||||
args := []string{
|
||||
fmt.Sprintf("user='%s'", config.Username),
|
||||
fmt.Sprintf("password='%s'", config.Password),
|
||||
}
|
||||
|
||||
if config.Host != "" {
|
||||
args = append(args, fmt.Sprintf("host=%s", config.Host))
|
||||
}
|
||||
|
||||
if config.Port > 0 {
|
||||
args = append(args, fmt.Sprintf("port=%d", config.Port))
|
||||
}
|
||||
|
||||
if config.Database != "" {
|
||||
args = append(args, fmt.Sprintf("dbname=%s", config.Database))
|
||||
}
|
||||
|
||||
args = append(args, fmt.Sprintf("connect_timeout=%d", int32(config.Timeout/time.Second)))
|
||||
|
||||
return strings.Join(args, " ")
|
||||
}
|
22
internal/storage/sql_provider_backend_sqlite.go
Normal file
22
internal/storage/sql_provider_backend_sqlite.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3" // Load the SQLite Driver used in the connection string.
|
||||
)
|
||||
|
||||
// SQLiteProvider is a SQLite3 provider.
|
||||
type SQLiteProvider struct {
|
||||
SQLProvider
|
||||
}
|
||||
|
||||
// NewSQLiteProvider constructs a SQLite provider.
|
||||
func NewSQLiteProvider(path string) (provider *SQLiteProvider) {
|
||||
provider = &SQLiteProvider{
|
||||
SQLProvider: NewSQLProvider(providerSQLite, "sqlite3", path),
|
||||
}
|
||||
|
||||
// All providers have differing SELECT existing table statements.
|
||||
provider.sqlSelectExistingTables = querySQLiteSelectExistingTables
|
||||
|
||||
return provider
|
||||
}
|
125
internal/storage/sql_provider_queries.go
Normal file
125
internal/storage/sql_provider_queries.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package storage
|
||||
|
||||
const (
|
||||
queryFmtSelectMigrations = `
|
||||
SELECT id, applied, version_before, version_after, application_version
|
||||
FROM %s;`
|
||||
|
||||
queryFmtSelectLatestMigration = `
|
||||
SELECT id, applied, version_before, version_after, application_version
|
||||
FROM %s
|
||||
ORDER BY id DESC
|
||||
LIMIT 1;`
|
||||
|
||||
queryFmtInsertMigration = `
|
||||
INSERT INTO %s (applied, version_before, version_after, application_version)
|
||||
VALUES (?, ?, ?, ?);`
|
||||
)
|
||||
|
||||
const (
|
||||
queryMySQLSelectExistingTables = `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_type = 'BASE TABLE' AND table_schema = database();`
|
||||
|
||||
queryPostgreSelectExistingTables = `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_type = 'BASE TABLE' AND table_schema = 'public';`
|
||||
|
||||
querySQLiteSelectExistingTables = `
|
||||
SELECT name
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table';`
|
||||
)
|
||||
|
||||
const (
|
||||
queryFmtSelectUserInfo = `
|
||||
SELECT second_factor_method, (SELECT EXISTS (SELECT id FROM %s WHERE username = ?)) AS has_totp, (SELECT EXISTS (SELECT id FROM %s WHERE username = ?)) AS has_u2f
|
||||
FROM %s
|
||||
WHERE username = ?;`
|
||||
|
||||
queryFmtSelectPreferred2FAMethod = `
|
||||
SELECT second_factor_method
|
||||
FROM %s
|
||||
WHERE username = ?;`
|
||||
|
||||
queryFmtUpsertPreferred2FAMethod = `
|
||||
REPLACE INTO %s (username, second_factor_method)
|
||||
VALUES (?, ?);`
|
||||
|
||||
queryFmtPostgresUpsertPreferred2FAMethod = `
|
||||
INSERT INTO %s (username, second_factor_method)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (username)
|
||||
DO UPDATE SET second_factor_method = $2;`
|
||||
)
|
||||
|
||||
const (
|
||||
queryFmtSelectExistsIdentityVerification = `
|
||||
SELECT EXISTS (
|
||||
SELECT id
|
||||
FROM %s
|
||||
WHERE token = ?
|
||||
);`
|
||||
|
||||
queryFmtInsertIdentityVerification = `
|
||||
INSERT INTO %s (token)
|
||||
VALUES (?);`
|
||||
|
||||
queryFmtDeleteIdentityVerification = `
|
||||
DELETE FROM %s
|
||||
WHERE token = ?;`
|
||||
)
|
||||
|
||||
const (
|
||||
queryFmtSelectTOTPConfiguration = `
|
||||
SELECT id, username, algorithm, digits, totp_period, secret
|
||||
FROM %s
|
||||
WHERE username = ?;`
|
||||
|
||||
queryFmtUpsertTOTPConfiguration = `
|
||||
REPLACE INTO %s (username, algorithm, digits, totp_period, secret)
|
||||
VALUES (?, ?, ?, ?, ?);`
|
||||
|
||||
queryFmtPostgresUpsertTOTPConfiguration = `
|
||||
INSERT INTO %s (username, algorithm, digits, totp_period, secret)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (username)
|
||||
DO UPDATE SET algorithm = $2, digits = $3, totp_period = $4, secret = $5;`
|
||||
|
||||
queryFmtDeleteTOTPConfiguration = `
|
||||
DELETE FROM %s
|
||||
WHERE username = ?;`
|
||||
)
|
||||
|
||||
const (
|
||||
queryFmtSelectU2FDevice = `
|
||||
SELECT key_handle, public_key
|
||||
FROM %s
|
||||
WHERE username = ?;`
|
||||
|
||||
queryFmtUpsertU2FDevice = `
|
||||
REPLACE INTO %s (username, key_handle, public_key)
|
||||
VALUES (?, ?, ?);`
|
||||
|
||||
queryFmtPostgresUpsertU2FDevice = `
|
||||
INSERT INTO %s (username, key_handle, public_key)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (username)
|
||||
DO UPDATE SET key_handle=$2, public_key=$3;`
|
||||
)
|
||||
|
||||
const (
|
||||
queryFmtInsertAuthenticationLogEntry = `
|
||||
INSERT INTO %s (time, successful, username)
|
||||
VALUES (?, ?, ?);`
|
||||
|
||||
queryFmtSelect1FAAuthenticationLogEntryByUsername = `
|
||||
SELECT time, successful, username
|
||||
FROM %s
|
||||
WHERE time > ? AND username = ?
|
||||
ORDER BY time DESC
|
||||
LIMIT ?
|
||||
OFFSET ?;`
|
||||
)
|
109
internal/storage/sql_provider_queries_special.go
Normal file
109
internal/storage/sql_provider_queries_special.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package storage
|
||||
|
||||
const (
|
||||
queryFmtDropTableIfExists = `DROP TABLE IF EXISTS %s;`
|
||||
|
||||
queryFmtRenameTable = `
|
||||
ALTER TABLE %s
|
||||
RENAME TO %s;`
|
||||
|
||||
queryFmtMySQLRenameTable = `
|
||||
ALTER TABLE %s
|
||||
RENAME %s;`
|
||||
)
|
||||
|
||||
// Pre1 migration constants.
|
||||
const (
|
||||
queryFmtPre1To1SelectAuthenticationLogs = `
|
||||
SELECT username, successful, time
|
||||
FROM %s
|
||||
ORDER BY time ASC
|
||||
LIMIT 100 OFFSET ?;`
|
||||
|
||||
queryFmtPre1To1InsertAuthenticationLogs = `
|
||||
INSERT INTO %s (username, successful, time)
|
||||
VALUES (?, ?, ?);`
|
||||
|
||||
queryFmtPre1InsertUserPreferencesFromSelect = `
|
||||
INSERT INTO %s (username, second_factor_method)
|
||||
SELECT username, second_factor_method
|
||||
FROM %s
|
||||
ORDER BY username ASC;`
|
||||
|
||||
queryFmtPre1SelectTOTPConfigurations = `
|
||||
SELECT username, secret
|
||||
FROM %s
|
||||
ORDER BY username ASC;`
|
||||
|
||||
queryFmtPre1InsertTOTPConfiguration = `
|
||||
INSERT INTO %s (username, secret)
|
||||
VALUES (?, ?);`
|
||||
|
||||
queryFmtPre1To1SelectU2FDevices = `
|
||||
SELECT username, keyHandle, publicKey
|
||||
FROM %s
|
||||
ORDER BY username ASC;`
|
||||
|
||||
queryFmtPre1To1InsertU2FDevice = `
|
||||
INSERT INTO %s (username, key_handle, public_key)
|
||||
VALUES (?, ?, ?);`
|
||||
|
||||
queryFmt1ToPre1InsertAuthenticationLogs = `
|
||||
INSERT INTO %s (username, successful, time)
|
||||
VALUES (?, ?, ?);`
|
||||
|
||||
queryFmt1ToPre1SelectAuthenticationLogs = `
|
||||
SELECT username, successful, time
|
||||
FROM %s
|
||||
ORDER BY id ASC
|
||||
LIMIT 100 OFFSET ?;`
|
||||
|
||||
queryFmt1ToPre1SelectU2FDevices = `
|
||||
SELECT username, key_handle, public_key
|
||||
FROM %s
|
||||
ORDER BY username ASC;`
|
||||
|
||||
queryFmt1ToPre1InsertU2FDevice = `
|
||||
INSERT INTO %s (username, keyHandle, publicKey)
|
||||
VALUES (?, ?, ?);`
|
||||
|
||||
queryCreatePre1 = `
|
||||
CREATE TABLE user_preferences (
|
||||
username VARCHAR(100),
|
||||
second_factor_method VARCHAR(11),
|
||||
PRIMARY KEY (username)
|
||||
);
|
||||
|
||||
CREATE TABLE identity_verification_tokens (
|
||||
token VARCHAR(512)
|
||||
);
|
||||
|
||||
CREATE TABLE totp_secrets (
|
||||
username VARCHAR(100),
|
||||
secret VARCHAR(64),
|
||||
PRIMARY KEY (username)
|
||||
);
|
||||
|
||||
CREATE TABLE u2f_devices (
|
||||
username VARCHAR(100),
|
||||
keyHandle TEXT,
|
||||
publicKey TEXT,
|
||||
PRIMARY KEY (username)
|
||||
);
|
||||
|
||||
CREATE TABLE authentication_logs (
|
||||
username VARCHAR(100),
|
||||
successful BOOL,
|
||||
time INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE config (
|
||||
category VARCHAR(32) NOT NULL,
|
||||
key_name VARCHAR(32) NOT NULL,
|
||||
value TEXT,
|
||||
PRIMARY KEY (category, key_name)
|
||||
);
|
||||
|
||||
INSERT INTO config (category, key_name, value)
|
||||
VALUES ('schema', 'version', '1');`
|
||||
)
|
327
internal/storage/sql_provider_schema.go
Normal file
327
internal/storage/sql_provider_schema.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// SchemaTables returns a list of tables.
|
||||
func (p *SQLProvider) SchemaTables(ctx context.Context) (tables []string, err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, p.sqlSelectExistingTables)
|
||||
if err != nil {
|
||||
return tables, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var table string
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&table)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
// SchemaVersion returns the version of the schema.
|
||||
func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error) {
|
||||
tables, err := p.SchemaTables(ctx)
|
||||
if err != nil {
|
||||
return -2, err
|
||||
}
|
||||
|
||||
if len(tables) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if utils.IsStringInSlice(tableMigrations, tables) {
|
||||
migration, err := p.schemaLatestMigration(ctx)
|
||||
if err != nil {
|
||||
return -2, err
|
||||
}
|
||||
|
||||
return migration.After, nil
|
||||
}
|
||||
|
||||
if utils.IsStringInSlice(tableUserPreferences, tables) && utils.IsStringInSlice(tablePre1TOTPSecrets, tables) &&
|
||||
utils.IsStringInSlice(tableU2FDevices, tables) && utils.IsStringInSlice(tableAuthenticationLogs, tables) &&
|
||||
utils.IsStringInSlice(tablePre1IdentityVerificationTokens, tables) && !utils.IsStringInSlice(tableMigrations, tables) {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// TODO: Decide if we want to support external tables.
|
||||
// return -2, ErrUnknownSchemaState
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaLatestMigration(ctx context.Context) (migration *models.Migration, err error) {
|
||||
migration = &models.Migration{}
|
||||
|
||||
err = p.db.QueryRowxContext(ctx, p.sqlSelectLatestMigration).StructScan(migration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return migration, nil
|
||||
}
|
||||
|
||||
// SchemaMigrationHistory returns migration history rows.
|
||||
func (p *SQLProvider) SchemaMigrationHistory(ctx context.Context) (migrations []models.Migration, err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, p.sqlSelectMigrations)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var migration models.Migration
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.StructScan(&migration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrations = append(migrations, migration)
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// SchemaMigrate migrates from the current version to the provided version.
|
||||
func (p *SQLProvider) SchemaMigrate(ctx context.Context, up bool, version int) (err error) {
|
||||
currentVersion, err := p.SchemaVersion(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = schemaMigrateChecks(p.name, up, version, currentVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.schemaMigrate(ctx, currentVersion, version)
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrate(ctx context.Context, prior, target int) (err error) {
|
||||
migrations, err := loadMigrations(p.name, prior, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(migrations) == 0 {
|
||||
return ErrNoMigrationsFound
|
||||
}
|
||||
|
||||
switch {
|
||||
case prior == -1:
|
||||
p.log.Infof(logFmtMigrationFromTo, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
|
||||
|
||||
err = p.schemaMigratePre1To1(ctx)
|
||||
if err != nil {
|
||||
if errRollback := p.schemaMigratePre1To1Rollback(ctx, true); errRollback != nil {
|
||||
return fmt.Errorf(errFmtFailedMigrationPre1, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf(errFmtFailedMigrationPre1, err)
|
||||
}
|
||||
case target == -1:
|
||||
p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), "pre1")
|
||||
default:
|
||||
p.log.Infof(logFmtMigrationFromTo, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
if prior == -1 && migration.Version == 1 {
|
||||
// Skip migration version 1 when upgrading from pre1 as it's applied as part of the pre1 upgrade.
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.schemaMigrateApply(ctx, migration)
|
||||
if err != nil {
|
||||
return p.schemaMigrateRollback(ctx, prior, migration.After(), err)
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case prior == -1:
|
||||
p.log.Infof(logFmtMigrationComplete, "pre1", strconv.Itoa(migrations[len(migrations)-1].After()))
|
||||
case target == -1:
|
||||
err = p.schemaMigrate1ToPre1(ctx)
|
||||
if err != nil {
|
||||
if errRollback := p.schemaMigratePre1To1Rollback(ctx, false); errRollback != nil {
|
||||
return fmt.Errorf(errFmtFailedMigrationPre1, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf(errFmtFailedMigrationPre1, err)
|
||||
}
|
||||
|
||||
p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), "pre1")
|
||||
default:
|
||||
p.log.Infof(logFmtMigrationComplete, strconv.Itoa(prior), strconv.Itoa(migrations[len(migrations)-1].After()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrateRollback(ctx context.Context, prior, after int, migrateErr error) (err error) {
|
||||
migrations, err := loadMigrations(p.name, after, prior)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading migrations from version %d to version %d for rollback: %+v. rollback caused by: %+v", prior, after, err, migrateErr)
|
||||
}
|
||||
|
||||
for _, migration := range migrations {
|
||||
if prior == -1 && !migration.Up && migration.Version == 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
err = p.schemaMigrateApply(ctx, migration)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error applying migration version %d to version %d for rollback: %+v. rollback caused by: %+v", migration.Before(), migration.After(), err, migrateErr)
|
||||
}
|
||||
}
|
||||
|
||||
if prior == -1 {
|
||||
if err = p.schemaMigrate1ToPre1(ctx); err != nil {
|
||||
return fmt.Errorf("error applying migration version 1 to version pre1 for rollback: %+v. rollback caused by: %+v", err, migrateErr)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("migration rollback complete. rollback caused by: %+v", migrateErr)
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrateApply(ctx context.Context, migration SchemaMigration) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, migration.Query)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
// Skip the migration history insertion in a migration to v0.
|
||||
if migration.Version == 1 && !migration.Up {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.schemaMigrateFinalize(ctx, migration)
|
||||
}
|
||||
|
||||
func (p SQLProvider) schemaMigrateFinalize(ctx context.Context, migration SchemaMigration) (err error) {
|
||||
return p.schemaMigrateFinalizeAdvanced(ctx, migration.Before(), migration.After())
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrateFinalizeAdvanced(ctx context.Context, before, after int) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlInsertMigration, time.Now(), before, after, utils.Version())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.log.Debugf("Storage schema migrated from version %d to %d", before, after)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SchemaMigrationsUp returns a list of migrations up available between the current version and the provided version.
|
||||
func (p *SQLProvider) SchemaMigrationsUp(ctx context.Context, version int) (migrations []SchemaMigration, err error) {
|
||||
current, err := p.SchemaVersion(ctx)
|
||||
if err != nil {
|
||||
return migrations, err
|
||||
}
|
||||
|
||||
if version == 0 {
|
||||
version = SchemaLatest
|
||||
}
|
||||
|
||||
if current >= version {
|
||||
return migrations, ErrNoAvailableMigrations
|
||||
}
|
||||
|
||||
return loadMigrations(p.name, current, version)
|
||||
}
|
||||
|
||||
// SchemaMigrationsDown returns a list of migrations down available between the current version and the provided version.
|
||||
func (p *SQLProvider) SchemaMigrationsDown(ctx context.Context, version int) (migrations []SchemaMigration, err error) {
|
||||
current, err := p.SchemaVersion(ctx)
|
||||
if err != nil {
|
||||
return migrations, err
|
||||
}
|
||||
|
||||
if current <= version {
|
||||
return migrations, ErrNoAvailableMigrations
|
||||
}
|
||||
|
||||
return loadMigrations(p.name, current, version)
|
||||
}
|
||||
|
||||
// SchemaLatestVersion returns the latest version available for migration..
|
||||
func (p *SQLProvider) SchemaLatestVersion() (version int, err error) {
|
||||
return latestMigrationVersion(p.name)
|
||||
}
|
||||
|
||||
func schemaMigrateChecks(providerName string, up bool, targetVersion, currentVersion int) (err error) {
|
||||
if targetVersion == currentVersion {
|
||||
return fmt.Errorf(ErrFmtMigrateAlreadyOnTargetVersion, targetVersion, currentVersion)
|
||||
}
|
||||
|
||||
latest, err := latestMigrationVersion(providerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if currentVersion > latest {
|
||||
return fmt.Errorf(errFmtSchemaCurrentGreaterThanLatestKnown, latest)
|
||||
}
|
||||
|
||||
if up {
|
||||
if targetVersion < currentVersion {
|
||||
return fmt.Errorf(ErrFmtMigrateUpTargetLessThanCurrent, targetVersion, currentVersion)
|
||||
}
|
||||
|
||||
if targetVersion == SchemaLatest && latest == currentVersion {
|
||||
return ErrSchemaAlreadyUpToDate
|
||||
}
|
||||
|
||||
if targetVersion != SchemaLatest && latest < targetVersion {
|
||||
return fmt.Errorf(ErrFmtMigrateUpTargetGreaterThanLatest, targetVersion, latest)
|
||||
}
|
||||
} else {
|
||||
if targetVersion < -1 {
|
||||
return fmt.Errorf(ErrFmtMigrateDownTargetLessThanMinimum, targetVersion)
|
||||
}
|
||||
|
||||
if targetVersion > currentVersion {
|
||||
return fmt.Errorf(ErrFmtMigrateDownTargetGreaterThanCurrent, targetVersion, currentVersion)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SchemaVersionToString returns a version string given a version number.
|
||||
func SchemaVersionToString(version int) (versionStr string) {
|
||||
switch version {
|
||||
case -2:
|
||||
return "unknown"
|
||||
case -1:
|
||||
return "pre1"
|
||||
case 0:
|
||||
return "N/A"
|
||||
default:
|
||||
return strconv.Itoa(version)
|
||||
}
|
||||
}
|
449
internal/storage/sql_provider_schema_pre1.go
Normal file
449
internal/storage/sql_provider_schema_pre1.go
Normal file
|
@ -0,0 +1,449 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
// schemaMigratePre1To1 takes the v1 migration and migrates to this version.
|
||||
func (p *SQLProvider) schemaMigratePre1To1(ctx context.Context) (err error) {
|
||||
migration, err := loadMigration(p.name, 1, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get Tables list.
|
||||
tables, err := p.SchemaTables(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tablesRename := []string{
|
||||
tablePre1Config,
|
||||
tablePre1TOTPSecrets,
|
||||
tablePre1IdentityVerificationTokens,
|
||||
tableU2FDevices,
|
||||
tableUserPreferences,
|
||||
tableAuthenticationLogs,
|
||||
tableAlphaPreferences,
|
||||
tableAlphaIdentityVerificationTokens,
|
||||
tableAlphaAuthenticationLogs,
|
||||
tableAlphaPreferencesTableName,
|
||||
tableAlphaSecondFactorPreferences,
|
||||
tableAlphaTOTPSecrets,
|
||||
tableAlphaU2FDeviceHandles,
|
||||
}
|
||||
|
||||
if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, migration.Query); err != nil {
|
||||
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect),
|
||||
tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.schemaMigratePre1To1AuthenticationLogs(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.schemaMigratePre1To1U2F(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.schemaMigratePre1To1TOTP(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, table := range tablesRename {
|
||||
if _, err = p.db.Exec(fmt.Sprintf(p.db.Rebind(queryFmtDropTableIfExists), tablePrefixBackup+table)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return p.schemaMigrateFinalizeAdvanced(ctx, -1, 1)
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigratePre1Rename(ctx context.Context, tables, tablesRename []string) (err error) {
|
||||
// Rename Tables and Indexes.
|
||||
for _, table := range tables {
|
||||
if !utils.IsStringInSlice(table, tablesRename) {
|
||||
continue
|
||||
}
|
||||
|
||||
tableNew := tablePrefixBackup + table
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.name == providerPostgres {
|
||||
if table == tableU2FDevices || table == tableUserPreferences {
|
||||
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`,
|
||||
tableNew, table, tableNew)); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigratePre1To1Rollback(ctx context.Context, up bool) (err error) {
|
||||
if up {
|
||||
migration, err := loadMigration(p.name, 1, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, migration.Query); err != nil {
|
||||
return fmt.Errorf(errFmtFailedMigration, migration.Version, migration.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
tables, err := p.SchemaTables(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if !strings.HasPrefix(table, tablePrefixBackup) {
|
||||
continue
|
||||
}
|
||||
|
||||
tableNew := strings.Replace(table, tablePrefixBackup, "", 1)
|
||||
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.sqlFmtRenameTable, table, tableNew)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.name == providerPostgres && (tableNew == tableU2FDevices || tableNew == tableUserPreferences) {
|
||||
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(`ALTER TABLE %s RENAME CONSTRAINT %s_pkey TO %s_pkey;`,
|
||||
tableNew, table, tableNew)); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogs(ctx context.Context) (err error) {
|
||||
for page := 0; true; page++ {
|
||||
attempts, err := p.schemaMigratePre1To1AuthenticationLogsGetRows(ctx, page)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
break
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
for _, attempt := range attempts {
|
||||
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(attempts) != 100 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigratePre1To1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []models.AuthenticationAttempt, err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attempts = make([]models.AuthenticationAttempt, 0, 100)
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
username string
|
||||
successful bool
|
||||
timestamp int64
|
||||
)
|
||||
|
||||
err = rows.Scan(&username, &successful, ×tamp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attempts = append(attempts, models.AuthenticationAttempt{Username: username, Successful: successful, Time: time.Unix(timestamp, 0)})
|
||||
}
|
||||
|
||||
return attempts, nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigratePre1To1TOTP(ctx context.Context) (err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tablePre1TOTPSecrets))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var totpConfigs []models.TOTPConfiguration
|
||||
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var username, secret string
|
||||
|
||||
err = rows.Scan(&username, &secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Add encryption migration here.
|
||||
encryptedSecret := "encrypted:" + secret
|
||||
|
||||
totpConfigs = append(totpConfigs, models.TOTPConfiguration{Username: username, Secret: encryptedSecret})
|
||||
}
|
||||
|
||||
for _, config := range totpConfigs {
|
||||
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tableTOTPConfigurations), config.Username, config.Secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) {
|
||||
rows, err := p.db.Queryx(fmt.Sprintf(p.db.Rebind(queryFmtPre1To1SelectU2FDevices), tablePrefixBackup+tableU2FDevices))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var devices []models.U2FDevice
|
||||
|
||||
for rows.Next() {
|
||||
var username, keyHandleBase64, publicKeyBase64 string
|
||||
|
||||
err = rows.Scan(&username, &keyHandleBase64, &publicKeyBase64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyHandle, err := base64.StdEncoding.DecodeString(keyHandleBase64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey, err := base64.StdEncoding.DecodeString(publicKeyBase64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: publicKey})
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1To1InsertU2FDevice), tableU2FDevices), device.Username, device.KeyHandle, device.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrate1ToPre1(ctx context.Context) (err error) {
|
||||
tables, err := p.SchemaTables(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tablesRename := []string{
|
||||
tableMigrations,
|
||||
tableTOTPConfigurations,
|
||||
tableIdentityVerification,
|
||||
tableU2FDevices,
|
||||
tableDUODevices,
|
||||
tableUserPreferences,
|
||||
tableAuthenticationLogs,
|
||||
}
|
||||
|
||||
if err = p.schemaMigratePre1Rename(ctx, tables, tablesRename); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := p.db.ExecContext(ctx, queryCreatePre1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertUserPreferencesFromSelect),
|
||||
tableUserPreferences, tablePrefixBackup+tableUserPreferences)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.schemaMigrate1ToPre1AuthenticationLogs(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.schemaMigrate1ToPre1U2F(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.schemaMigrate1ToPre1TOTP(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queryFmtDropTableRebound := p.db.Rebind(queryFmtDropTableIfExists)
|
||||
|
||||
for _, table := range tablesRename {
|
||||
if _, err = p.db.Exec(fmt.Sprintf(queryFmtDropTableRebound, tablePrefixBackup+table)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogs(ctx context.Context) (err error) {
|
||||
for page := 0; true; page++ {
|
||||
attempts, err := p.schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx, page)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
break
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
for _, attempt := range attempts {
|
||||
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertAuthenticationLogs), tableAuthenticationLogs), attempt.Username, attempt.Successful, attempt.Time.Unix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(attempts) != 100 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrate1ToPre1AuthenticationLogsGetRows(ctx context.Context, page int) (attempts []models.AuthenticationAttempt, err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectAuthenticationLogs), tablePrefixBackup+tableAuthenticationLogs), page*100)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attempts = make([]models.AuthenticationAttempt, 0, 100)
|
||||
|
||||
var attempt models.AuthenticationAttempt
|
||||
for rows.Next() {
|
||||
err = rows.StructScan(&attempt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
attempts = append(attempts, attempt)
|
||||
}
|
||||
|
||||
return attempts, nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrate1ToPre1TOTP(ctx context.Context) (err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1SelectTOTPConfigurations), tablePrefixBackup+tableTOTPConfigurations))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var totpConfigs []models.TOTPConfiguration
|
||||
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var username, encryptedSecret string
|
||||
|
||||
err = rows.Scan(&username, &encryptedSecret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Fix.
|
||||
// TODO: Add DECRYPTION migration here.
|
||||
decryptedSecret := strings.Replace(encryptedSecret, "encrypted:", "", 1)
|
||||
|
||||
totpConfigs = append(totpConfigs, models.TOTPConfiguration{Username: username, Secret: decryptedSecret})
|
||||
}
|
||||
|
||||
for _, config := range totpConfigs {
|
||||
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmtPre1InsertTOTPConfiguration), tablePre1TOTPSecrets), config.Username, config.Secret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1SelectU2FDevices), tablePrefixBackup+tableU2FDevices))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var (
|
||||
devices []models.U2FDevice
|
||||
device models.U2FDevice
|
||||
)
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.StructScan(&device)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
devices = append(devices, device)
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
_, err = p.db.ExecContext(ctx, fmt.Sprintf(p.db.Rebind(queryFmt1ToPre1InsertU2FDevice), tableU2FDevices), device.Username, base64.StdEncoding.EncodeToString(device.KeyHandle), base64.StdEncoding.EncodeToString(device.PublicKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
134
internal/storage/sql_provider_schema_test.go
Normal file
134
internal/storage/sql_provider_schema_test.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShouldReturnErrOnTargetSameAsCurrent(t *testing.T) {
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, true, 1, 1),
|
||||
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, false, 1, 1),
|
||||
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, false, 2, 2),
|
||||
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 2, 2))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerMySQL, false, 1, 1),
|
||||
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerPostgres, false, 1, 1),
|
||||
fmt.Sprintf(ErrFmtMigrateAlreadyOnTargetVersion, 1, 1))
|
||||
}
|
||||
|
||||
func TestShouldReturnErrOnUpMigrationTargetVersionLessTHanCurrent(t *testing.T) {
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerPostgres, true, 0, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion))
|
||||
|
||||
assert.NoError(t,
|
||||
schemaMigrateChecks(providerPostgres, true, testLatestVersion, 0))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, true, 0, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion))
|
||||
|
||||
assert.NoError(t,
|
||||
schemaMigrateChecks(providerSQLite, true, testLatestVersion, 0))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerMySQL, true, 0, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateUpTargetLessThanCurrent, 0, testLatestVersion))
|
||||
|
||||
assert.NoError(t,
|
||||
schemaMigrateChecks(providerMySQL, true, testLatestVersion, 0))
|
||||
}
|
||||
|
||||
func TestMigrationUpShouldReturnErrOnAlreadyLatest(t *testing.T) {
|
||||
assert.Equal(t,
|
||||
ErrSchemaAlreadyUpToDate,
|
||||
schemaMigrateChecks(providerPostgres, true, SchemaLatest, testLatestVersion))
|
||||
|
||||
assert.Equal(t,
|
||||
ErrSchemaAlreadyUpToDate,
|
||||
schemaMigrateChecks(providerMySQL, true, SchemaLatest, testLatestVersion))
|
||||
|
||||
assert.Equal(t,
|
||||
ErrSchemaAlreadyUpToDate,
|
||||
schemaMigrateChecks(providerSQLite, true, SchemaLatest, testLatestVersion))
|
||||
}
|
||||
|
||||
func TestShouldReturnErrOnVersionDoesntExits(t *testing.T) {
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerPostgres, true, SchemaLatest-1, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerMySQL, true, SchemaLatest-1, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, true, SchemaLatest-1, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateUpTargetGreaterThanLatest, SchemaLatest-1, testLatestVersion))
|
||||
}
|
||||
|
||||
func TestMigrationDownShouldReturnErrOnTargetLessThanPre1(t *testing.T) {
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, false, -4, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -4))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerMySQL, false, -2, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -2))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerPostgres, false, -2, testLatestVersion),
|
||||
fmt.Sprintf(ErrFmtMigrateDownTargetLessThanMinimum, -2))
|
||||
|
||||
assert.NoError(t,
|
||||
schemaMigrateChecks(providerPostgres, false, -1, testLatestVersion))
|
||||
}
|
||||
|
||||
func TestMigrationDownShouldReturnErrOnTargetVersionGreaterThanCurrent(t *testing.T) {
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, false, testLatestVersion, 0),
|
||||
fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerMySQL, false, testLatestVersion, 0),
|
||||
fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerPostgres, false, testLatestVersion, 0),
|
||||
fmt.Sprintf(ErrFmtMigrateDownTargetGreaterThanCurrent, testLatestVersion, 0))
|
||||
}
|
||||
|
||||
func TestShouldReturnErrWhenCurrentIsGreaterThanLatest(t *testing.T) {
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerPostgres, true, SchemaLatest-4, SchemaLatest-5),
|
||||
fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerMySQL, true, SchemaLatest-4, SchemaLatest-5),
|
||||
fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion))
|
||||
|
||||
assert.EqualError(t,
|
||||
schemaMigrateChecks(providerSQLite, true, SchemaLatest-4, SchemaLatest-5),
|
||||
fmt.Sprintf(errFmtSchemaCurrentGreaterThanLatestKnown, testLatestVersion))
|
||||
}
|
||||
|
||||
func TestSchemaVersionToString(t *testing.T) {
|
||||
assert.Equal(t, "unknown", SchemaVersionToString(-2))
|
||||
assert.Equal(t, "pre1", SchemaVersionToString(-1))
|
||||
assert.Equal(t, "N/A", SchemaVersionToString(0))
|
||||
assert.Equal(t, "1", SchemaVersionToString(1))
|
||||
assert.Equal(t, "2", SchemaVersionToString(2))
|
||||
}
|
|
@ -1,400 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/authentication"
|
||||
"github.com/authelia/authelia/v4/internal/models"
|
||||
)
|
||||
|
||||
const currentSchemaMockSchemaVersion = "1"
|
||||
|
||||
func TestSQLInitializeDatabase(t *testing.T) {
|
||||
provider, mock := NewSQLMockProvider()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"name"})
|
||||
mock.ExpectQuery(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'").
|
||||
WillReturnRows(rows)
|
||||
|
||||
mock.ExpectBegin()
|
||||
|
||||
keys := make([]string, 0, len(sqlUpgradeCreateTableStatements[1]))
|
||||
for k := range sqlUpgradeCreateTableStatements[1] {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, table := range keys {
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("CREATE TABLE %s .*", table)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
}
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
|
||||
WithArgs("schema", "version", "1").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
mock.ExpectCommit()
|
||||
|
||||
err := provider.initialize(provider.db)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSQLUpgradeDatabase(t *testing.T) {
|
||||
provider, mock := NewSQLMockProvider()
|
||||
|
||||
mock.ExpectQuery(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow(userPreferencesTableName).
|
||||
AddRow(identityVerificationTokensTableName).
|
||||
AddRow(totpSecretsTableName).
|
||||
AddRow(u2fDeviceHandlesTableName).
|
||||
AddRow(authenticationLogsTableName))
|
||||
|
||||
mock.ExpectBegin()
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("CREATE TABLE %s .*", configTableName)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("CREATE INDEX IF NOT EXISTS usr_time_idx ON %s .*", authenticationLogsTableName)).
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("REPLACE INTO %s \\(category, key_name, value\\) VALUES \\(\\?, \\?, \\?\\)", configTableName)).
|
||||
WithArgs("schema", "version", "1").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
mock.ExpectCommit()
|
||||
|
||||
err := provider.initialize(provider.db)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSQLProviderMethodsAuthenticationLogs(t *testing.T) {
|
||||
provider, mock := NewSQLMockProvider()
|
||||
|
||||
mock.ExpectQuery(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow(userPreferencesTableName).
|
||||
AddRow(identityVerificationTokensTableName).
|
||||
AddRow(totpSecretsTableName).
|
||||
AddRow(u2fDeviceHandlesTableName).
|
||||
AddRow(authenticationLogsTableName).
|
||||
AddRow(configTableName))
|
||||
|
||||
args := []driver.Value{"schema", "version"}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
||||
AddRow("1"))
|
||||
|
||||
err := provider.initialize(provider.db)
|
||||
assert.NoError(t, err)
|
||||
|
||||
attempts := []models.AuthenticationAttempt{
|
||||
{Username: unitTestUser, Successful: true, Time: time.Unix(1577880001, 0)},
|
||||
{Username: unitTestUser, Successful: true, Time: time.Unix(1577880002, 0)},
|
||||
{Username: unitTestUser, Successful: false, Time: time.Unix(1577880003, 0)},
|
||||
}
|
||||
|
||||
rows := sqlmock.NewRows([]string{"successful", "time"})
|
||||
|
||||
for id, attempt := range attempts {
|
||||
args = []driver.Value{attempt.Username, attempt.Successful, attempt.Time.Unix()}
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("INSERT INTO %s \\(username, successful, time\\) VALUES \\(\\?, \\?, \\?\\)", authenticationLogsTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnResult(sqlmock.NewResult(int64(id), 1))
|
||||
|
||||
err := provider.AppendAuthenticationLog(attempt)
|
||||
assert.NoError(t, err)
|
||||
rows.AddRow(attempt.Successful, attempt.Time.Unix())
|
||||
}
|
||||
|
||||
args = []driver.Value{1577880000, unitTestUser}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(rows)
|
||||
|
||||
after := time.Unix(1577880000, 0)
|
||||
results, err := provider.LoadLatestAuthenticationLogs(unitTestUser, after)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, results, 3)
|
||||
assert.Equal(t, unitTestUser, results[0].Username)
|
||||
assert.Equal(t, true, results[0].Successful)
|
||||
assert.Equal(t, time.Unix(1577880001, 0), results[0].Time)
|
||||
assert.Equal(t, unitTestUser, results[1].Username)
|
||||
assert.Equal(t, true, results[1].Successful)
|
||||
assert.Equal(t, time.Unix(1577880002, 0), results[1].Time)
|
||||
assert.Equal(t, unitTestUser, results[2].Username)
|
||||
assert.Equal(t, false, results[2].Successful)
|
||||
assert.Equal(t, time.Unix(1577880003, 0), results[2].Time)
|
||||
|
||||
// Test Blank Rows.
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT successful, time FROM %s WHERE time>\\? AND username=\\? ORDER BY time DESC", authenticationLogsTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"successful", "time"}))
|
||||
|
||||
results, err = provider.LoadLatestAuthenticationLogs(unitTestUser, after)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, results, 0)
|
||||
}
|
||||
|
||||
func TestSQLProviderMethodsPreferred(t *testing.T) {
|
||||
provider, mock := NewSQLMockProvider()
|
||||
|
||||
mock.ExpectQuery(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow(userPreferencesTableName).
|
||||
AddRow(identityVerificationTokensTableName).
|
||||
AddRow(totpSecretsTableName).
|
||||
AddRow(u2fDeviceHandlesTableName).
|
||||
AddRow(authenticationLogsTableName).
|
||||
AddRow(configTableName))
|
||||
|
||||
args := []driver.Value{"schema", "version"}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
||||
AddRow(currentSchemaMockSchemaVersion))
|
||||
|
||||
err := provider.initialize(provider.db)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("REPLACE INTO %s \\(username, second_factor_method\\) VALUES \\(\\?, \\?\\)", userPreferencesTableName)).
|
||||
WithArgs(unitTestUser, authentication.TOTP).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err = provider.SavePreferred2FAMethod(unitTestUser, authentication.TOTP)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
|
||||
WithArgs(unitTestUser).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}).AddRow(authentication.TOTP))
|
||||
|
||||
method, err := provider.LoadPreferred2FAMethod(unitTestUser)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, authentication.TOTP, method)
|
||||
|
||||
// Test Blank Rows.
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=\\?", userPreferencesTableName)).
|
||||
WithArgs(unitTestUser).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"second_factor_method"}))
|
||||
|
||||
method, err = provider.LoadPreferred2FAMethod(unitTestUser)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", method)
|
||||
}
|
||||
|
||||
func TestSQLProviderMethodsTOTP(t *testing.T) {
|
||||
provider, mock := NewSQLMockProvider()
|
||||
|
||||
mock.ExpectQuery(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow(userPreferencesTableName).
|
||||
AddRow(identityVerificationTokensTableName).
|
||||
AddRow(totpSecretsTableName).
|
||||
AddRow(u2fDeviceHandlesTableName).
|
||||
AddRow(authenticationLogsTableName).
|
||||
AddRow(configTableName))
|
||||
|
||||
args := []driver.Value{"schema", "version"}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
||||
AddRow(currentSchemaMockSchemaVersion))
|
||||
|
||||
err := provider.initialize(provider.db)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pretendSecret := "abc123"
|
||||
args = []driver.Value{unitTestUser, pretendSecret}
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("REPLACE INTO %s \\(username, secret\\) VALUES \\(\\?, \\?\\)", totpSecretsTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err = provider.SaveTOTPSecret(unitTestUser, pretendSecret)
|
||||
assert.NoError(t, err)
|
||||
|
||||
args = []driver.Value{unitTestUser}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"secret"}).AddRow(pretendSecret))
|
||||
|
||||
secret, err := provider.LoadTOTPSecret(unitTestUser)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pretendSecret, secret)
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("DELETE FROM %s WHERE username=\\?", totpSecretsTableName)).
|
||||
WithArgs(unitTestUser).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err = provider.DeleteTOTPSecret(unitTestUser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT secret FROM %s WHERE username=\\?", totpSecretsTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"secret"}))
|
||||
|
||||
// Test Blank Rows
|
||||
secret, err = provider.LoadTOTPSecret(unitTestUser)
|
||||
assert.EqualError(t, err, "no TOTP secret registered")
|
||||
assert.Equal(t, "", secret)
|
||||
}
|
||||
|
||||
func TestSQLProviderMethodsU2F(t *testing.T) {
|
||||
provider, mock := NewSQLMockProvider()
|
||||
|
||||
mock.ExpectQuery(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow(userPreferencesTableName).
|
||||
AddRow(identityVerificationTokensTableName).
|
||||
AddRow(totpSecretsTableName).
|
||||
AddRow(u2fDeviceHandlesTableName).
|
||||
AddRow(authenticationLogsTableName).
|
||||
AddRow(configTableName))
|
||||
|
||||
args := []driver.Value{"schema", "version"}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
||||
AddRow(currentSchemaMockSchemaVersion))
|
||||
|
||||
err := provider.initialize(provider.db)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pretendKeyHandle := []byte("abc")
|
||||
pretendPublicKey := []byte("123")
|
||||
pretendKeyHandleB64 := base64.StdEncoding.EncodeToString(pretendKeyHandle)
|
||||
pretendPublicKeyB64 := base64.StdEncoding.EncodeToString(pretendPublicKey)
|
||||
|
||||
args = []driver.Value{unitTestUser, pretendKeyHandleB64, pretendPublicKeyB64}
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("REPLACE INTO %s \\(username, keyHandle, publicKey\\) VALUES \\(\\?, \\?, \\?\\)", u2fDeviceHandlesTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err = provider.SaveU2FDeviceHandle(unitTestUser, pretendKeyHandle, pretendPublicKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
args = []driver.Value{unitTestUser}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}).
|
||||
AddRow(pretendKeyHandleB64, pretendPublicKeyB64))
|
||||
|
||||
keyHandle, publicKey, err := provider.LoadU2FDeviceHandle(unitTestUser)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, pretendKeyHandle, keyHandle)
|
||||
assert.Equal(t, pretendPublicKey, publicKey)
|
||||
|
||||
// Test Blank Rows.
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=\\?", u2fDeviceHandlesTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"keyHandle", "publicKey"}))
|
||||
|
||||
keyHandle, publicKey, err = provider.LoadU2FDeviceHandle(unitTestUser)
|
||||
assert.EqualError(t, err, "no U2F device handle found")
|
||||
assert.Equal(t, []byte(nil), keyHandle)
|
||||
assert.Equal(t, []byte(nil), publicKey)
|
||||
}
|
||||
|
||||
func TestSQLProviderMethodsIdentityVerificationTokens(t *testing.T) {
|
||||
provider, mock := NewSQLMockProvider()
|
||||
|
||||
mock.ExpectQuery(
|
||||
"SELECT name FROM sqlite_master WHERE type='table'").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"name"}).
|
||||
AddRow(userPreferencesTableName).
|
||||
AddRow(identityVerificationTokensTableName).
|
||||
AddRow(totpSecretsTableName).
|
||||
AddRow(u2fDeviceHandlesTableName).
|
||||
AddRow(authenticationLogsTableName).
|
||||
AddRow(configTableName))
|
||||
|
||||
args := []driver.Value{"schema", "version"}
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT value FROM %s WHERE category=\\? AND key_name=\\?", configTableName)).
|
||||
WithArgs(args...).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"value"}).
|
||||
AddRow(currentSchemaMockSchemaVersion))
|
||||
|
||||
err := provider.initialize(provider.db)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fakeIdentityVerificationToken := "abc"
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("INSERT INTO %s \\(token\\) VALUES \\(\\?\\)", identityVerificationTokensTableName)).
|
||||
WithArgs(fakeIdentityVerificationToken).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
err = provider.SaveIdentityVerificationToken(fakeIdentityVerificationToken)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
|
||||
WithArgs(fakeIdentityVerificationToken).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
|
||||
AddRow(true))
|
||||
|
||||
valid, err := provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, valid)
|
||||
|
||||
mock.ExpectExec(
|
||||
fmt.Sprintf("DELETE FROM %s WHERE token=\\?", identityVerificationTokensTableName)).
|
||||
WithArgs(fakeIdentityVerificationToken).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
err = provider.RemoveIdentityVerificationToken(fakeIdentityVerificationToken)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mock.ExpectQuery(
|
||||
fmt.Sprintf("SELECT EXISTS \\(SELECT \\* FROM %s WHERE token=\\?\\)", identityVerificationTokensTableName)).
|
||||
WithArgs(fakeIdentityVerificationToken).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"EXISTS"}).
|
||||
AddRow(false))
|
||||
|
||||
valid, err = provider.FindIdentityVerificationToken(fakeIdentityVerificationToken)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, valid)
|
||||
}
|
|
@ -1,58 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // Load the SQLite Driver used in the connection string.
|
||||
)
|
||||
|
||||
// SQLiteProvider is a SQLite3 provider.
|
||||
type SQLiteProvider struct {
|
||||
SQLProvider
|
||||
}
|
||||
|
||||
// NewSQLiteProvider constructs a SQLite provider.
|
||||
func NewSQLiteProvider(path string) *SQLiteProvider {
|
||||
provider := SQLiteProvider{
|
||||
SQLProvider{
|
||||
name: "sqlite",
|
||||
|
||||
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
|
||||
sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements,
|
||||
|
||||
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName),
|
||||
sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName),
|
||||
|
||||
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName),
|
||||
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName),
|
||||
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName),
|
||||
|
||||
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
|
||||
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
|
||||
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName),
|
||||
|
||||
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
|
||||
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
|
||||
|
||||
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
|
||||
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),
|
||||
|
||||
sqlGetExistingTables: "SELECT name FROM sqlite_master WHERE type='table'",
|
||||
|
||||
sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName),
|
||||
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName),
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
provider.log.Fatalf("Unable to create SQL database %s: %s", path, err)
|
||||
}
|
||||
|
||||
if err := provider.initialize(db); err != nil {
|
||||
provider.log.Fatalf("Unable to initialize SQL database %s: %s", path, err)
|
||||
}
|
||||
|
||||
return &provider
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
// SQLMockProvider is a SQLMock provider.
|
||||
type SQLMockProvider struct {
|
||||
SQLProvider
|
||||
}
|
||||
|
||||
// NewSQLMockProvider constructs a SQLMock provider.
|
||||
func NewSQLMockProvider() (*SQLMockProvider, sqlmock.Sqlmock) {
|
||||
provider := SQLMockProvider{
|
||||
SQLProvider{
|
||||
name: "sqlmock",
|
||||
|
||||
sqlUpgradesCreateTableStatements: sqlUpgradeCreateTableStatements,
|
||||
sqlUpgradesCreateTableIndexesStatements: sqlUpgradesCreateTableIndexesStatements,
|
||||
|
||||
sqlGetPreferencesByUsername: fmt.Sprintf("SELECT second_factor_method FROM %s WHERE username=?", userPreferencesTableName),
|
||||
sqlUpsertSecondFactorPreference: fmt.Sprintf("REPLACE INTO %s (username, second_factor_method) VALUES (?, ?)", userPreferencesTableName),
|
||||
|
||||
sqlTestIdentityVerificationTokenExistence: fmt.Sprintf("SELECT EXISTS (SELECT * FROM %s WHERE token=?)", identityVerificationTokensTableName),
|
||||
sqlInsertIdentityVerificationToken: fmt.Sprintf("INSERT INTO %s (token) VALUES (?)", identityVerificationTokensTableName),
|
||||
sqlDeleteIdentityVerificationToken: fmt.Sprintf("DELETE FROM %s WHERE token=?", identityVerificationTokensTableName),
|
||||
|
||||
sqlGetTOTPSecretByUsername: fmt.Sprintf("SELECT secret FROM %s WHERE username=?", totpSecretsTableName),
|
||||
sqlUpsertTOTPSecret: fmt.Sprintf("REPLACE INTO %s (username, secret) VALUES (?, ?)", totpSecretsTableName),
|
||||
sqlDeleteTOTPSecret: fmt.Sprintf("DELETE FROM %s WHERE username=?", totpSecretsTableName),
|
||||
|
||||
sqlGetU2FDeviceHandleByUsername: fmt.Sprintf("SELECT keyHandle, publicKey FROM %s WHERE username=?", u2fDeviceHandlesTableName),
|
||||
sqlUpsertU2FDeviceHandle: fmt.Sprintf("REPLACE INTO %s (username, keyHandle, publicKey) VALUES (?, ?, ?)", u2fDeviceHandlesTableName),
|
||||
|
||||
sqlInsertAuthenticationLog: fmt.Sprintf("INSERT INTO %s (username, successful, time) VALUES (?, ?, ?)", authenticationLogsTableName),
|
||||
sqlGetLatestAuthenticationLogs: fmt.Sprintf("SELECT successful, time FROM %s WHERE time>? AND username=? ORDER BY time DESC", authenticationLogsTableName),
|
||||
|
||||
sqlGetExistingTables: "SELECT name FROM sqlite_master WHERE type='table'",
|
||||
|
||||
sqlConfigSetValue: fmt.Sprintf("REPLACE INTO %s (category, key_name, value) VALUES (?, ?, ?)", configTableName),
|
||||
sqlConfigGetValue: fmt.Sprintf("SELECT value FROM %s WHERE category=? AND key_name=?", configTableName),
|
||||
},
|
||||
}
|
||||
|
||||
db, mock, err := sqlmock.New()
|
||||
|
||||
if err != nil {
|
||||
provider.log.Fatalf("Unable to create SQL database: %s", err)
|
||||
}
|
||||
|
||||
provider.db = db
|
||||
|
||||
/*
|
||||
We do initialize in the tests rather than in the new up.
|
||||
*/
|
||||
|
||||
return &provider, mock
|
||||
}
|
|
@ -1,18 +1,28 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// SchemaVersion is a simple int representation of the schema version.
|
||||
type SchemaVersion int
|
||||
|
||||
// ToString converts the schema version into a string and returns that converted value.
|
||||
func (s SchemaVersion) ToString() string {
|
||||
return strconv.Itoa(int(s))
|
||||
// SchemaMigration represents an intended migration.
|
||||
type SchemaMigration struct {
|
||||
Version int
|
||||
Name string
|
||||
Provider string
|
||||
Up bool
|
||||
Query string
|
||||
}
|
||||
|
||||
type transaction interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
// Before returns the version the schema should be at Before the migration is applied.
|
||||
func (m SchemaMigration) Before() (before int) {
|
||||
if m.Up {
|
||||
return m.Version - 1
|
||||
}
|
||||
|
||||
return m.Version
|
||||
}
|
||||
|
||||
// After returns the version the schema will be at After the migration is applied.
|
||||
func (m SchemaMigration) After() (after int) {
|
||||
if m.Up {
|
||||
return m.Version
|
||||
}
|
||||
|
||||
return m.Version - 1
|
||||
}
|
||||
|
|
|
@ -1,76 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/authelia/authelia/v4/internal/utils"
|
||||
)
|
||||
|
||||
func (p *SQLProvider) upgradeCreateTableStatements(tx transaction, statements map[string]string, existingTables []string) error {
|
||||
keys := make([]string, 0, len(statements))
|
||||
for k := range statements {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, table := range keys {
|
||||
if !utils.IsStringInSlice(table, existingTables) {
|
||||
_, err := tx.Exec(fmt.Sprintf(statements[table], table))
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create table %s: %v", table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) upgradeRunMultipleStatements(tx transaction, statements []string) error {
|
||||
for _, statement := range statements {
|
||||
_, err := tx.Exec(statement)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeFinalize sets the schema version and logs a message, as well as any other future finalization tasks.
|
||||
func (p *SQLProvider) upgradeFinalize(tx transaction, version SchemaVersion) error {
|
||||
_, err := tx.Exec(p.sqlConfigSetValue, "schema", "version", version.ToString())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.log.Debugf("%s%d", storageSchemaUpgradeMessage, version)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// upgradeSchemaToVersion001 upgrades the schema to version 1.
|
||||
func (p *SQLProvider) upgradeSchemaToVersion001(tx transaction, tables []string) error {
|
||||
version := SchemaVersion(1)
|
||||
|
||||
err := p.upgradeCreateTableStatements(tx, p.sqlUpgradesCreateTableStatements[version], tables)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip mysql create index statements. It doesn't support CREATE INDEX IF NOT EXIST. May be able to work around this with an Index struct.
|
||||
if p.name != "mysql" {
|
||||
err = p.upgradeRunMultipleStatements(tx, p.sqlUpgradesCreateTableIndexesStatements[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create index: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = p.upgradeFinalize(tx, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -122,7 +122,9 @@ func (s *StandaloneWebDriverSuite) TestShouldCheckUserIsAskedToRegisterDevice()
|
|||
|
||||
// Clean up any TOTP secret already in DB.
|
||||
provider := storage.NewSQLiteProvider("/tmp/db.sqlite3")
|
||||
require.NoError(s.T(), provider.DeleteTOTPSecret(username))
|
||||
|
||||
require.NoError(s.T(), provider.StartupCheck())
|
||||
require.NoError(s.T(), provider.DeleteTOTPConfiguration(ctx, username))
|
||||
|
||||
// Login one factor.
|
||||
s.doLoginOneFactor(s.T(), s.Context(ctx), username, password, false, "")
|
||||
|
|
Loading…
Reference in New Issue
Block a user