mirror of
https://github.com/0rangebananaspy/authelia.git
synced 2024-09-14 22:47:21 +07:00
feat(storage): encrypt u2f key (#2664)
Adds encryption to the U2F public keys. While the public keys cannot be used to authenticate, only to validate someone is authenticated, if a rogue operator changed these in the database they may be able to bypass 2FA. This prevents that.
This commit is contained in:
parent
104a61ecd6
commit
255aaeb2ad
|
@ -31,12 +31,12 @@ required: yes
|
|||
{: .label .label-config .label-red }
|
||||
</div>
|
||||
|
||||
The encryption key used to encrypt data in the database. It has a minimum length of 20 and must be provided. We encrypt
|
||||
data by creating a sha256 checksum of the provided value, and use that to encrypt the data with the AES-GCM 256bit
|
||||
algorithm.
|
||||
The encryption key used to encrypt data in the database. We encrypt data by creating a sha256 checksum of the provided
|
||||
value, and use that to encrypt the data with the AES-GCM 256bit algorithm.
|
||||
|
||||
The encrypted data in the database is as follows:
|
||||
- TOTP Secret
|
||||
The minimum length of this key is 20 characters, however we generally recommend above 64 characters.
|
||||
|
||||
See [securty measures](../../security/measures.md#storage-security-measures) for more information.
|
||||
|
||||
### local
|
||||
See [SQLite](./sqlite.md).
|
||||
|
|
|
@ -81,6 +81,54 @@ LDAP implementations vary, so please ask if you need some assistance in configur
|
|||
These protections can be [tuned](../configuration/authentication/ldap.md#refresh-interval) according to your security
|
||||
policy by changing refresh_interval, however we believe that 5 minutes is a fairly safe interval.
|
||||
|
||||
## Storage security measures
|
||||
|
||||
We force users to encrypt vulnerable data stored in the database. It is strongly advised you do not give this encryption
|
||||
key to anyone. In the instance of a database installation that multiple users have access to, you should aim to ensure
|
||||
that users who have access to the database do not also have access to this key.
|
||||
|
||||
The encrypted data in the database is as follows:
|
||||
|
||||
|Table |Column |Rational |
|
||||
|:-----------------:|:--------:|:----------------------------------------------------------------------------------------------------:|
|
||||
|totp_configurations|secret |Prevents a [Leaked Database](#leaked-database) or [Bad Actors](#bad-actors) from compromising security|
|
||||
|u2f_devices |public_key|Prevents [Bad Actors](#bad-actors) from compromising security |
|
||||
|
||||
### Leaked Database
|
||||
|
||||
A leaked database can reasonably compromise security if there are credentials that are not encrypted. Columns encrypted
|
||||
for this purpose prevent this attack vector.
|
||||
|
||||
### Bad Actors
|
||||
|
||||
A bad actor who has the SQL password and access to the database can theoretically change another users credential, this
|
||||
theoretically bypasses authentication. Columns encrypted for this purpose prevent this attack vector.
|
||||
|
||||
A bad actor may also be able to use data in the database to bypass 2FA silently depending on the credentials. In the
|
||||
instance of the U2F public key this is not possible, they can only change it which would eventually alert the user in
|
||||
question. But in the case of TOTP they can use the secret to authenticate without knowledge of the user in question.
|
||||
|
||||
### Encryption key management
|
||||
|
||||
You must supply the encryption key in the recommended method of a [secret](../configuration/secrets.md) or in one of
|
||||
the other methods available for [configuration](../configuration/index.md#configuration).
|
||||
|
||||
If you wish to change your encryption key for any reason you can do so using the following steps:
|
||||
|
||||
1. Run the `authelia --version` command to determine the version of Authelia you're running and either download that
|
||||
version or run another container of that version interactively. All the subsequent commands assume you're running
|
||||
the `authelia` binary in the current working directory. You will have to adjust this according to how you're running
|
||||
it.
|
||||
2. Run the `./authelia storage encryption change-key --help` command.
|
||||
3. Stop Authelia.
|
||||
- You can skip this step, however note that any data changed between the time you make the change and the time when
|
||||
you stop Authelia i.e. via user registering a device; will be encrypted with the incorrect key.
|
||||
4. Run the `./authelia storage encryption change-key` command with the appropriate parameters.
|
||||
- The help from step 1 will be useful here. The easiest method to accomplish this is with the `--config`,
|
||||
`--encryption-key`, and `--new-encryption-key` parameters.
|
||||
5. Update the encryption key Authelia uses on startup.
|
||||
6. Start Authelia.
|
||||
|
||||
## Notifier security measures (SMTP)
|
||||
|
||||
The SMTP Notifier implementation does not allow connections that are not secure without changing default configuration
|
||||
|
|
|
@ -52,7 +52,6 @@ If properly configured, Authelia guarantees the following for security of your u
|
|||
* Binding session cookies to single IP addresses.
|
||||
* Authenticate communication between Authelia and reverse proxy.
|
||||
* Securely transmit authentication data to backends (OAuth2 with bearer tokens).
|
||||
* Protect secrets stored in the database with encryption to prevent secrets leak by database exfiltration.
|
||||
* Least privilege on LDAP binding operations (currently administrative user is used to bind while it could be anonymous
|
||||
for most operations).
|
||||
* Extend the check of user group memberships to authentication backends other than LDAP (File currently).
|
||||
|
|
|
@ -65,7 +65,7 @@ func (s *FirstFactorSuite) TestShouldFailIfUserProviderCheckPasswordFail() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthType1FA,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
|
@ -93,7 +93,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsNotMarkedWhenProviderC
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthType1FA,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
|
@ -119,7 +119,7 @@ func (s *FirstFactorSuite) TestShouldCheckAuthenticationIsMarkedWhenInvalidCrede
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthType1FA,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(`{
|
||||
|
|
|
@ -34,21 +34,21 @@ func (s *HandlerRegisterU2FStep1Suite) TearDownTest() {
|
|||
s.mock.Close()
|
||||
}
|
||||
|
||||
func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
||||
verification = models.NewIdentityVerification(username, action)
|
||||
func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
||||
verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP())
|
||||
|
||||
verification.ExpiresAt = expiresAt
|
||||
|
||||
claims := verification.ToIdentityVerificationClaim()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
ss, _ := token.SignedString([]byte(secret))
|
||||
ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret))
|
||||
|
||||
return ss, verification
|
||||
}
|
||||
|
||||
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissing() {
|
||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration,
|
||||
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
|
||||
time.Now().Add(1*time.Minute))
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
|
@ -57,7 +57,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
|
|||
Return(true, nil)
|
||||
|
||||
s.mock.StorageMock.EXPECT().
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
||||
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||
Return(nil)
|
||||
|
||||
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
||||
|
@ -68,7 +68,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedProtoIsMissi
|
|||
|
||||
func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissing() {
|
||||
s.mock.Ctx.Request.Header.Add("X-Forwarded-Proto", "http")
|
||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", ActionU2FRegistration,
|
||||
token, verification := createToken(s.mock, "john", ActionU2FRegistration,
|
||||
time.Now().Add(1*time.Minute))
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
|
@ -77,7 +77,7 @@ func (s *HandlerRegisterU2FStep1Suite) TestShouldRaiseWhenXForwardedHostIsMissin
|
|||
Return(true, nil)
|
||||
|
||||
s.mock.StorageMock.EXPECT().
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
||||
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||
Return(nil)
|
||||
|
||||
SecondFactorU2FIdentityFinish(s.mock.Ctx)
|
||||
|
|
|
@ -97,7 +97,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldAutoSelect() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
@ -286,7 +286,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldUseInvalidMethodAndAutoSelect() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
@ -414,7 +414,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldCallDuoAPIAndDenyAccess() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
@ -497,7 +497,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToDefaultURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
@ -546,7 +546,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotReturnRedirectURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
@ -591,7 +591,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRedirectUserToSafeTargetURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
@ -640,7 +640,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldNotRedirectToUnsafeURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
@ -687,7 +687,7 @@ func (s *SecondFactorDuoPostSuite) TestShouldRegenerateSessionForPreventingSessi
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeDuo,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
})).
|
||||
Return(nil)
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToDefaultURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeTOTP,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
||||
|
@ -85,7 +85,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotReturnRedirectURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeTOTP,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
||||
|
@ -115,7 +115,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRedirectUserToSafeTargetURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeTOTP,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.TOTPMock.EXPECT().Validate(gomock.Eq("abc"), gomock.Eq(&config)).Return(true, nil)
|
||||
|
@ -146,7 +146,7 @@ func (s *HandlerSignTOTPSuite) TestShouldNotRedirectToUnsafeURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeTOTP,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.TOTPMock.EXPECT().
|
||||
|
@ -180,7 +180,7 @@ func (s *HandlerSignTOTPSuite) TestShouldRegenerateSessionForPreventingSessionFi
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeTOTP,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.TOTPMock.EXPECT().
|
||||
|
|
|
@ -51,7 +51,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToDefaultURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeU2F,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
s.mock.Ctx.Configuration.DefaultRedirectionURL = testRedirectionURL
|
||||
|
@ -83,7 +83,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotReturnRedirectURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeU2F,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||
|
@ -111,7 +111,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRedirectUserToSafeTargetURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeU2F,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||
|
@ -142,7 +142,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldNotRedirectToUnsafeURL() {
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeU2F,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||
|
@ -171,7 +171,7 @@ func (s *HandlerSignU2FStep2Suite) TestShouldRegenerateSessionForPreventingSessi
|
|||
Banned: false,
|
||||
Time: s.mock.Clock.Now(),
|
||||
Type: regulation.AuthTypeU2F,
|
||||
RemoteIP: models.NewIPAddressFromString("0.0.0.0"),
|
||||
RemoteIP: models.NewNullIPFromString("0.0.0.0"),
|
||||
}))
|
||||
|
||||
bodyBytes, err := json.Marshal(signU2FRequestBody{
|
||||
|
|
|
@ -27,7 +27,7 @@ func IdentityVerificationStart(args IdentityVerificationStartArgs) RequestHandle
|
|||
return
|
||||
}
|
||||
|
||||
verification := models.NewIdentityVerification(identity.Username, args.ActionClaim)
|
||||
verification := models.NewIdentityVerification(identity.Username, args.ActionClaim, ctx.RemoteIP())
|
||||
|
||||
// Create the claim with the action to sign it.
|
||||
claims := verification.ToIdentityVerificationClaim()
|
||||
|
@ -183,7 +183,7 @@ func IdentityVerificationFinish(args IdentityVerificationFinishArgs, next func(c
|
|||
return
|
||||
}
|
||||
|
||||
err = ctx.Providers.StorageProvider.RemoveIdentityVerification(ctx, claims.ID)
|
||||
err = ctx.Providers.StorageProvider.ConsumeIdentityVerification(ctx, claims.ID, models.NewNullIP(ctx.RemoteIP()))
|
||||
if err != nil {
|
||||
ctx.Error(err, messageOperationFailed)
|
||||
return
|
||||
|
|
|
@ -165,15 +165,15 @@ func (s *IdentityVerificationFinishProcess) TearDownTest() {
|
|||
s.mock.Close()
|
||||
}
|
||||
|
||||
func createToken(secret, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
||||
verification = models.NewIdentityVerification(username, action)
|
||||
func createToken(ctx *mocks.MockAutheliaCtx, username, action string, expiresAt time.Time) (data string, verification models.IdentityVerification) {
|
||||
verification = models.NewIdentityVerification(username, action, ctx.Ctx.RemoteIP())
|
||||
|
||||
verification.ExpiresAt = expiresAt
|
||||
|
||||
claims := verification.ToIdentityVerificationClaim()
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
ss, _ := token.SignedString([]byte(secret))
|
||||
ss, _ := token.SignedString([]byte(ctx.Ctx.Configuration.JWTSecret))
|
||||
|
||||
return ss, verification
|
||||
}
|
||||
|
@ -203,7 +203,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotProvided()
|
|||
}
|
||||
|
||||
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsNotFoundInDB() {
|
||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "Login",
|
||||
token, verification := createToken(s.mock, "john", "Login",
|
||||
time.Now().Add(1*time.Minute))
|
||||
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
@ -229,7 +229,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenIsInvalid() {
|
|||
|
||||
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
|
||||
args := newArgs(defaultRetriever)
|
||||
token, _ := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", args.ActionClaim,
|
||||
token, _ := createToken(s.mock, "john", args.ActionClaim,
|
||||
time.Now().Add(-1*time.Minute))
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
|
@ -240,7 +240,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenExpired() {
|
|||
}
|
||||
|
||||
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
|
||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "", "",
|
||||
token, verification := createToken(s.mock, "", "",
|
||||
time.Now().Add(1*time.Minute))
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
|
@ -255,7 +255,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongAction() {
|
|||
}
|
||||
|
||||
func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
|
||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "harry", "EXP_ACTION",
|
||||
token, verification := createToken(s.mock, "harry", "EXP_ACTION",
|
||||
time.Now().Add(1*time.Minute))
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
|
@ -272,7 +272,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailForWrongUser() {
|
|||
}
|
||||
|
||||
func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemovedFromDB() {
|
||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION",
|
||||
token, verification := createToken(s.mock, "john", "EXP_ACTION",
|
||||
time.Now().Add(1*time.Minute))
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
|
@ -281,7 +281,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
|
|||
Return(true, nil)
|
||||
|
||||
s.mock.StorageMock.EXPECT().
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
||||
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||
Return(fmt.Errorf("cannot remove"))
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
@ -291,7 +291,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldFailIfTokenCannotBeRemoved
|
|||
}
|
||||
|
||||
func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete() {
|
||||
token, verification := createToken(s.mock.Ctx.Configuration.JWTSecret, "john", "EXP_ACTION",
|
||||
token, verification := createToken(s.mock, "john", "EXP_ACTION",
|
||||
time.Now().Add(1*time.Minute))
|
||||
s.mock.Ctx.Request.SetBodyString(fmt.Sprintf("{\"token\":\"%s\"}", token))
|
||||
|
||||
|
@ -300,7 +300,7 @@ func (s *IdentityVerificationFinishProcess) TestShouldReturn200OnFinishComplete(
|
|||
Return(true, nil)
|
||||
|
||||
s.mock.StorageMock.EXPECT().
|
||||
RemoveIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String())).
|
||||
ConsumeIdentityVerification(s.mock.Ctx, gomock.Eq(verification.JTI.String()), gomock.Eq(models.NewNullIP(s.mock.Ctx.RemoteIP()))).
|
||||
Return(nil)
|
||||
|
||||
middlewares.IdentityVerificationFinish(newFinishArgs(), next)(s.mock.Ctx)
|
||||
|
|
|
@ -65,6 +65,20 @@ func (mr *MockStorageMockRecorder) Close() *gomock.Call {
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockStorage)(nil).Close))
|
||||
}
|
||||
|
||||
// ConsumeIdentityVerification mocks base method.
|
||||
func (m *MockStorage) ConsumeIdentityVerification(arg0 context.Context, arg1 string, arg2 models.NullIP) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ConsumeIdentityVerification", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ConsumeIdentityVerification indicates an expected call of ConsumeIdentityVerification.
|
||||
func (mr *MockStorageMockRecorder) ConsumeIdentityVerification(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConsumeIdentityVerification", reflect.TypeOf((*MockStorage)(nil).ConsumeIdentityVerification), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// DeletePreferredDuoDevice mocks base method.
|
||||
func (m *MockStorage) DeletePreferredDuoDevice(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -198,6 +212,21 @@ func (mr *MockStorageMockRecorder) LoadU2FDevice(arg0, arg1 interface{}) *gomock
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevice", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevice), arg0, arg1)
|
||||
}
|
||||
|
||||
// LoadU2FDevices mocks base method.
|
||||
func (m *MockStorage) LoadU2FDevices(arg0 context.Context, arg1, arg2 int) ([]models.U2FDevice, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LoadU2FDevices", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].([]models.U2FDevice)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// LoadU2FDevices indicates an expected call of LoadU2FDevices.
|
||||
func (mr *MockStorageMockRecorder) LoadU2FDevices(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadU2FDevices", reflect.TypeOf((*MockStorage)(nil).LoadU2FDevices), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// LoadUserInfo mocks base method.
|
||||
func (m *MockStorage) LoadUserInfo(arg0 context.Context, arg1 string) (models.UserInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -213,20 +242,6 @@ func (mr *MockStorageMockRecorder) LoadUserInfo(arg0, arg1 interface{}) *gomock.
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadUserInfo", reflect.TypeOf((*MockStorage)(nil).LoadUserInfo), arg0, arg1)
|
||||
}
|
||||
|
||||
// RemoveIdentityVerification mocks base method.
|
||||
func (m *MockStorage) RemoveIdentityVerification(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RemoveIdentityVerification", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RemoveIdentityVerification indicates an expected call of RemoveIdentityVerification.
|
||||
func (mr *MockStorageMockRecorder) RemoveIdentityVerification(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveIdentityVerification", reflect.TypeOf((*MockStorage)(nil).RemoveIdentityVerification), arg0, arg1)
|
||||
}
|
||||
|
||||
// SaveIdentityVerification mocks base method.
|
||||
func (m *MockStorage) SaveIdentityVerification(arg0 context.Context, arg1 models.IdentityVerification) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -442,17 +457,3 @@ func (mr *MockStorageMockRecorder) StartupCheck() *gomock.Call {
|
|||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartupCheck", reflect.TypeOf((*MockStorage)(nil).StartupCheck))
|
||||
}
|
||||
|
||||
// UpdateTOTPConfigurationSecret mocks base method.
|
||||
func (m *MockStorage) UpdateTOTPConfigurationSecret(arg0 context.Context, arg1 models.TOTPConfiguration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateTOTPConfigurationSecret", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateTOTPConfigurationSecret indicates an expected call of UpdateTOTPConfigurationSecret.
|
||||
func (mr *MockStorageMockRecorder) UpdateTOTPConfigurationSecret(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTOTPConfigurationSecret", reflect.TypeOf((*MockStorage)(nil).UpdateTOTPConfigurationSecret), arg0, arg1)
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ type AuthenticationAttempt struct {
|
|||
Banned bool `db:"banned"`
|
||||
Username string `db:"username"`
|
||||
Type string `db:"auth_type"`
|
||||
RemoteIP IPAddress `db:"remote_ip"`
|
||||
RemoteIP NullIP `db:"remote_ip"`
|
||||
RequestURI string `db:"request_uri"`
|
||||
RequestMethod string `db:"request_method"`
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package models
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
|
@ -8,25 +9,28 @@ import (
|
|||
)
|
||||
|
||||
// NewIdentityVerification creates a new IdentityVerification from a given username and action.
|
||||
func NewIdentityVerification(username, action string) (verification IdentityVerification) {
|
||||
func NewIdentityVerification(username, action string, ip net.IP) (verification IdentityVerification) {
|
||||
return IdentityVerification{
|
||||
JTI: uuid.New(),
|
||||
IssuedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
Action: action,
|
||||
Username: username,
|
||||
IssuedIP: NewIP(ip),
|
||||
}
|
||||
}
|
||||
|
||||
// IdentityVerification represents an identity verification row in the database.
|
||||
type IdentityVerification struct {
|
||||
ID int `db:"id"`
|
||||
JTI uuid.UUID `db:"jti"`
|
||||
IssuedAt time.Time `db:"iat"`
|
||||
ExpiresAt time.Time `db:"exp"`
|
||||
Used *time.Time `db:"used"`
|
||||
Action string `db:"action"`
|
||||
Username string `db:"username"`
|
||||
ID int `db:"id"`
|
||||
JTI uuid.UUID `db:"jti"`
|
||||
IssuedAt time.Time `db:"iat"`
|
||||
IssuedIP IP `db:"issued_ip"`
|
||||
ExpiresAt time.Time `db:"exp"`
|
||||
Action string `db:"action"`
|
||||
Username string `db:"username"`
|
||||
Consumed *time.Time `db:"consumed"`
|
||||
ConsumedIP NullIP `db:"consumed_ip"`
|
||||
}
|
||||
|
||||
// ToIdentityVerificationClaim converts the IdentityVerification into a IdentityVerificationClaim.
|
||||
|
|
|
@ -2,23 +2,71 @@ package models
|
|||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// NewIPAddressFromString converts a string into an IPAddress.
|
||||
func NewIPAddressFromString(ip string) (ipAddress IPAddress) {
|
||||
actualIP := net.ParseIP(ip)
|
||||
return IPAddress{IP: &actualIP}
|
||||
// NewIP easily constructs a new IP.
|
||||
func NewIP(value net.IP) (ip IP) {
|
||||
return IP{IP: value}
|
||||
}
|
||||
|
||||
// IPAddress is a type specific for storage of a net.IP in the database.
|
||||
type IPAddress struct {
|
||||
*net.IP
|
||||
// NewNullIP easily constructs a new NullIP.
|
||||
func NewNullIP(value net.IP) (ip NullIP) {
|
||||
return NullIP{IP: value}
|
||||
}
|
||||
|
||||
// Value is the IPAddress implementation of the databases/sql driver.Valuer.
|
||||
func (ip IPAddress) Value() (value driver.Value, err error) {
|
||||
// NewNullIPFromString easily constructs a new NullIP from a string.
|
||||
func NewNullIPFromString(value string) (ip NullIP) {
|
||||
if value == "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
return NullIP{IP: net.ParseIP(value)}
|
||||
}
|
||||
|
||||
// IP is a type specific for storage of a net.IP in the database which can't be NULL.
|
||||
type IP struct {
|
||||
IP net.IP
|
||||
}
|
||||
|
||||
// Value is the IP implementation of the databases/sql driver.Valuer.
|
||||
func (ip IP) Value() (value driver.Value, err error) {
|
||||
if ip.IP == nil {
|
||||
return nil, errors.New("cannot value nil IP to driver.Value")
|
||||
}
|
||||
|
||||
return driver.Value(ip.IP.String()), nil
|
||||
}
|
||||
|
||||
// Scan is the IP implementation of the sql.Scanner.
|
||||
func (ip *IP) Scan(src interface{}) (err error) {
|
||||
if src == nil {
|
||||
return errors.New("cannot scan nil to type IP")
|
||||
}
|
||||
|
||||
var value string
|
||||
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
value = v
|
||||
default:
|
||||
return fmt.Errorf("invalid type %T for IP %v", src, src)
|
||||
}
|
||||
|
||||
ip.IP = net.ParseIP(value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NullIP is a type specific for storage of a net.IP in the database which can also be NULL.
|
||||
type NullIP struct {
|
||||
IP net.IP
|
||||
}
|
||||
|
||||
// Value is the NullIP implementation of the databases/sql driver.Valuer.
|
||||
func (ip NullIP) Value() (value driver.Value, err error) {
|
||||
if ip.IP == nil {
|
||||
return driver.Value(nil), nil
|
||||
}
|
||||
|
@ -26,8 +74,8 @@ func (ip IPAddress) Value() (value driver.Value, err error) {
|
|||
return driver.Value(ip.IP.String()), nil
|
||||
}
|
||||
|
||||
// Scan is the IPAddress implementation of the sql.Scanner.
|
||||
func (ip *IPAddress) Scan(src interface{}) (err error) {
|
||||
// Scan is the NullIP implementation of the sql.Scanner.
|
||||
func (ip *NullIP) Scan(src interface{}) (err error) {
|
||||
if src == nil {
|
||||
ip.IP = nil
|
||||
return nil
|
||||
|
@ -39,10 +87,10 @@ func (ip *IPAddress) Scan(src interface{}) (err error) {
|
|||
case string:
|
||||
value = v
|
||||
default:
|
||||
return fmt.Errorf("invalid type %T for IPAddress %v", src, src)
|
||||
return fmt.Errorf("invalid type %T for NullIP %v", src, src)
|
||||
}
|
||||
|
||||
*ip.IP = net.ParseIP(value)
|
||||
ip.IP = net.ParseIP(value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ func (r *Regulator) Mark(ctx context.Context, successful, banned bool, username,
|
|||
Banned: banned,
|
||||
Username: username,
|
||||
Type: authType,
|
||||
RemoteIP: models.IPAddress{IP: &remoteIP},
|
||||
RemoteIP: models.NewNullIP(remoteIP),
|
||||
RequestURI: requestURI,
|
||||
RequestMethod: requestMethod,
|
||||
})
|
||||
|
|
|
@ -23,9 +23,11 @@ const (
|
|||
|
||||
// WARNING: Do not change/remove these consts. They are used for Pre1 migrations.
|
||||
const (
|
||||
tablePre1TOTPSecrets = "totp_secrets"
|
||||
tablePre1Config = "config"
|
||||
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
|
||||
tablePre1TOTPSecrets = "totp_secrets"
|
||||
tablePre1IdentityVerificationTokens = "identity_verification_tokens"
|
||||
|
||||
tablePre1Config = "config"
|
||||
|
||||
tableAlphaAuthenticationLogs = "AuthenticationLogs"
|
||||
tableAlphaIdentityVerificationTokens = "IdentityVerificationTokens"
|
||||
tableAlphaPreferences = "Preferences"
|
||||
|
@ -35,6 +37,15 @@ const (
|
|||
tableAlphaU2FDeviceHandles = "U2FDeviceHandles"
|
||||
)
|
||||
|
||||
var tablesPre1 = []string{
|
||||
tablePre1TOTPSecrets,
|
||||
tablePre1IdentityVerificationTokens,
|
||||
|
||||
tableUserPreferences,
|
||||
tableU2FDevices,
|
||||
tableAuthenticationLogs,
|
||||
}
|
||||
|
||||
const (
|
||||
providerAll = "all"
|
||||
providerMySQL = "mysql"
|
||||
|
|
|
@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
|
|||
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
||||
remote_ip VARCHAR(47) NULL DEFAULT NULL,
|
||||
remote_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
request_uri TEXT NOT NULL,
|
||||
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
||||
PRIMARY KEY (id)
|
||||
|
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
|
|||
id INTEGER AUTO_INCREMENT,
|
||||
jti CHAR(36),
|
||||
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
issued_ip VARCHAR(39) NOT NULL,
|
||||
exp TIMESTAMP NOT NULL,
|
||||
used TIMESTAMP NULL DEFAULT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
action VARCHAR(50) NOT NULL,
|
||||
consumed TIMESTAMP NULL DEFAULT NULL,
|
||||
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE KEY (jti)
|
||||
);
|
||||
|
|
|
@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
|
|||
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
||||
remote_ip VARCHAR(47) NULL DEFAULT NULL,
|
||||
remote_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
request_uri TEXT,
|
||||
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
||||
PRIMARY KEY (id)
|
||||
|
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
|
|||
id SERIAL,
|
||||
jti CHAR(36),
|
||||
iat TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
issued_ip VARCHAR(39) NOT NULL,
|
||||
exp TIMESTAMP WITH TIME ZONE NOT NULL,
|
||||
used TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
action VARCHAR(50) NOT NULL,
|
||||
consumed TIMESTAMP WITH TIME ZONE NULL DEFAULT NULL,
|
||||
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (jti)
|
||||
);
|
||||
|
|
|
@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS authentication_logs (
|
|||
banned BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
auth_type VARCHAR(8) NOT NULL DEFAULT '1FA',
|
||||
remote_ip VARCHAR(47) NULL DEFAULT NULL,
|
||||
remote_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
request_uri TEXT,
|
||||
request_method VARCHAR(8) NOT NULL DEFAULT '',
|
||||
PRIMARY KEY (id)
|
||||
|
@ -18,10 +18,12 @@ CREATE TABLE IF NOT EXISTS identity_verification (
|
|||
id INTEGER,
|
||||
jti VARCHAR(36),
|
||||
iat TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
issued_ip VARCHAR(39) NOT NULL,
|
||||
exp TIMESTAMP NOT NULL,
|
||||
used TIMESTAMP NULL DEFAULT NULL,
|
||||
username VARCHAR(100) NOT NULL,
|
||||
action VARCHAR(50) NOT NULL,
|
||||
consumed TIMESTAMP NULL DEFAULT NULL,
|
||||
consumed_ip VARCHAR(39) NULL DEFAULT NULL,
|
||||
PRIMARY KEY (id),
|
||||
UNIQUE (jti)
|
||||
);
|
||||
|
|
|
@ -18,17 +18,17 @@ type Provider interface {
|
|||
LoadUserInfo(ctx context.Context, username string) (info models.UserInfo, err error)
|
||||
|
||||
SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error)
|
||||
RemoveIdentityVerification(ctx context.Context, jti string) (err error)
|
||||
ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error)
|
||||
FindIdentityVerification(ctx context.Context, jti string) (found bool, err error)
|
||||
|
||||
SaveTOTPConfiguration(ctx context.Context, config models.TOTPConfiguration) (err error)
|
||||
DeleteTOTPConfiguration(ctx context.Context, username string) (err error)
|
||||
LoadTOTPConfiguration(ctx context.Context, username string) (config *models.TOTPConfiguration, err error)
|
||||
LoadTOTPConfigurations(ctx context.Context, limit, page int) (configs []models.TOTPConfiguration, err error)
|
||||
UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error)
|
||||
|
||||
SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error)
|
||||
LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error)
|
||||
LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error)
|
||||
|
||||
SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error)
|
||||
DeletePreferredDuoDevice(ctx context.Context, username string) (err error)
|
||||
|
|
|
@ -34,7 +34,7 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
sqlSelectAuthenticationAttemptsByUsername: fmt.Sprintf(queryFmtSelect1FAAuthenticationLogEntryByUsername, tableAuthenticationLogs),
|
||||
|
||||
sqlInsertIdentityVerification: fmt.Sprintf(queryFmtInsertIdentityVerification, tableIdentityVerification),
|
||||
sqlDeleteIdentityVerification: fmt.Sprintf(queryFmtDeleteIdentityVerification, tableIdentityVerification),
|
||||
sqlConsumeIdentityVerification: fmt.Sprintf(queryFmtConsumeIdentityVerification, tableIdentityVerification),
|
||||
sqlSelectExistsIdentityVerification: fmt.Sprintf(queryFmtSelectExistsIdentityVerification, tableIdentityVerification),
|
||||
|
||||
sqlUpsertTOTPConfig: fmt.Sprintf(queryFmtUpsertTOTPConfiguration, tableTOTPConfigurations),
|
||||
|
@ -45,8 +45,12 @@ func NewSQLProvider(config *schema.Configuration, name, driverName, dataSourceNa
|
|||
sqlUpdateTOTPConfigSecret: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecret, tableTOTPConfigurations),
|
||||
sqlUpdateTOTPConfigSecretByUsername: fmt.Sprintf(queryFmtUpdateTOTPConfigurationSecretByUsername, tableTOTPConfigurations),
|
||||
|
||||
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
|
||||
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
|
||||
sqlUpsertU2FDevice: fmt.Sprintf(queryFmtUpsertU2FDevice, tableU2FDevices),
|
||||
sqlSelectU2FDevice: fmt.Sprintf(queryFmtSelectU2FDevice, tableU2FDevices),
|
||||
sqlSelectU2FDevices: fmt.Sprintf(queryFmtSelectU2FDevices, tableU2FDevices),
|
||||
|
||||
sqlUpdateU2FDevicePublicKey: fmt.Sprintf(queryFmtUpdateU2FDevicePublicKey, tableU2FDevices),
|
||||
sqlUpdateU2FDevicePublicKeyByUsername: fmt.Sprintf(queryFmtUpdateUpdateU2FDevicePublicKeyByUsername, tableU2FDevices),
|
||||
|
||||
sqlUpsertDuoDevice: fmt.Sprintf(queryFmtUpsertDuoDevice, tableDuoDevices),
|
||||
sqlDeleteDuoDevice: fmt.Sprintf(queryFmtDeleteDuoDevice, tableDuoDevices),
|
||||
|
@ -86,7 +90,7 @@ type SQLProvider struct {
|
|||
|
||||
// Table: identity_verification.
|
||||
sqlInsertIdentityVerification string
|
||||
sqlDeleteIdentityVerification string
|
||||
sqlConsumeIdentityVerification string
|
||||
sqlSelectExistsIdentityVerification string
|
||||
|
||||
// Table: totp_configurations.
|
||||
|
@ -99,8 +103,12 @@ type SQLProvider struct {
|
|||
sqlUpdateTOTPConfigSecretByUsername string
|
||||
|
||||
// Table: u2f_devices.
|
||||
sqlUpsertU2FDevice string
|
||||
sqlSelectU2FDevice string
|
||||
sqlUpsertU2FDevice string
|
||||
sqlSelectU2FDevice string
|
||||
sqlSelectU2FDevices string
|
||||
|
||||
sqlUpdateU2FDevicePublicKey string
|
||||
sqlUpdateU2FDevicePublicKeyByUsername string
|
||||
|
||||
// Table: duo_devices
|
||||
sqlUpsertDuoDevice string
|
||||
|
@ -217,7 +225,7 @@ func (p *SQLProvider) LoadUserInfo(ctx context.Context, username string) (info m
|
|||
// SaveIdentityVerification save an identity verification record to the database.
|
||||
func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification models.IdentityVerification) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlInsertIdentityVerification,
|
||||
verification.JTI, verification.IssuedAt, verification.ExpiresAt,
|
||||
verification.JTI, verification.IssuedAt, verification.IssuedIP, verification.ExpiresAt,
|
||||
verification.Username, verification.Action); err != nil {
|
||||
return fmt.Errorf("error inserting identity verification: %w", err)
|
||||
}
|
||||
|
@ -225,9 +233,9 @@ func (p *SQLProvider) SaveIdentityVerification(ctx context.Context, verification
|
|||
return nil
|
||||
}
|
||||
|
||||
// RemoveIdentityVerification remove an identity verification record from the database.
|
||||
func (p *SQLProvider) RemoveIdentityVerification(ctx context.Context, jti string) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlDeleteIdentityVerification, jti); err != nil {
|
||||
// ConsumeIdentityVerification marks an identity verification record in the database as consumed.
|
||||
func (p *SQLProvider) ConsumeIdentityVerification(ctx context.Context, jti string, ip models.NullIP) (err error) {
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlConsumeIdentityVerification, ip, jti); err != nil {
|
||||
return fmt.Errorf("error updating identity verification: %w", err)
|
||||
}
|
||||
|
||||
|
@ -321,8 +329,7 @@ func (p *SQLProvider) LoadTOTPConfigurations(ctx context.Context, limit, page in
|
|||
return configs, nil
|
||||
}
|
||||
|
||||
// UpdateTOTPConfigurationSecret updates a TOTP configuration secret.
|
||||
func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) {
|
||||
func (p *SQLProvider) updateTOTPConfigurationSecret(ctx context.Context, config models.TOTPConfiguration) (err error) {
|
||||
switch config.ID {
|
||||
case 0:
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpdateTOTPConfigSecretByUsername, config.Secret, config.Username)
|
||||
|
@ -339,6 +346,10 @@ func (p *SQLProvider) UpdateTOTPConfigurationSecret(ctx context.Context, config
|
|||
|
||||
// SaveU2FDevice saves a registered U2F device.
|
||||
func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice) (err error) {
|
||||
if device.PublicKey, err = p.encrypt(device.PublicKey); err != nil {
|
||||
return fmt.Errorf("error encrypting the U2F device public key: %v", err)
|
||||
}
|
||||
|
||||
if _, err = p.db.ExecContext(ctx, p.sqlUpsertU2FDevice, device.Username, device.Description, device.KeyHandle, device.PublicKey); err != nil {
|
||||
return fmt.Errorf("error upserting U2F device: %v", err)
|
||||
}
|
||||
|
@ -348,9 +359,7 @@ func (p *SQLProvider) SaveU2FDevice(ctx context.Context, device models.U2FDevice
|
|||
|
||||
// LoadU2FDevice loads a U2F device registration for a given username.
|
||||
func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (device *models.U2FDevice, err error) {
|
||||
device = &models.U2FDevice{
|
||||
Username: username,
|
||||
}
|
||||
device = &models.U2FDevice{}
|
||||
|
||||
if err = p.db.GetContext(ctx, device, p.sqlSelectU2FDevice, username); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
|
@ -360,9 +369,64 @@ func (p *SQLProvider) LoadU2FDevice(ctx context.Context, username string) (devic
|
|||
return nil, fmt.Errorf("error selecting U2F device: %w", err)
|
||||
}
|
||||
|
||||
if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err)
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// LoadU2FDevices loads U2F device registrations.
|
||||
func (p *SQLProvider) LoadU2FDevices(ctx context.Context, limit, page int) (devices []models.U2FDevice, err error) {
|
||||
rows, err := p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, limit, limit*page)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("error selecting U2F devices: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := rows.Close(); err != nil {
|
||||
p.log.Errorf(logFmtErrClosingConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
devices = make([]models.U2FDevice, 0, limit)
|
||||
|
||||
var device models.U2FDevice
|
||||
|
||||
for rows.Next() {
|
||||
if err = rows.StructScan(&device); err != nil {
|
||||
return nil, fmt.Errorf("error scanning U2F device to struct: %w", err)
|
||||
}
|
||||
|
||||
if device.PublicKey, err = p.decrypt(device.PublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error decrypting the U2F device public key: %v", err)
|
||||
}
|
||||
|
||||
devices = append(devices, device)
|
||||
}
|
||||
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) updateU2FDevicePublicKey(ctx context.Context, device models.U2FDevice) (err error) {
|
||||
switch device.ID {
|
||||
case 0:
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKeyByUsername, device.PublicKey, device.Username)
|
||||
default:
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpdateU2FDevicePublicKey, device.PublicKey, device.ID)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating U2F public key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePreferredDuoDevice saves a Duo device.
|
||||
func (p *SQLProvider) SavePreferredDuoDevice(ctx context.Context, device models.DuoDevice) (err error) {
|
||||
_, err = p.db.ExecContext(ctx, p.sqlUpsertDuoDevice, device.Username, device.Device, device.Method)
|
||||
|
|
|
@ -38,13 +38,16 @@ func NewPostgreSQLProvider(config *schema.Configuration) (provider *PostgreSQLPr
|
|||
provider.sqlSelectUserInfo = provider.db.Rebind(provider.sqlSelectUserInfo)
|
||||
provider.sqlSelectExistsIdentityVerification = provider.db.Rebind(provider.sqlSelectExistsIdentityVerification)
|
||||
provider.sqlInsertIdentityVerification = provider.db.Rebind(provider.sqlInsertIdentityVerification)
|
||||
provider.sqlDeleteIdentityVerification = provider.db.Rebind(provider.sqlDeleteIdentityVerification)
|
||||
provider.sqlConsumeIdentityVerification = provider.db.Rebind(provider.sqlConsumeIdentityVerification)
|
||||
provider.sqlSelectTOTPConfig = provider.db.Rebind(provider.sqlSelectTOTPConfig)
|
||||
provider.sqlDeleteTOTPConfig = provider.db.Rebind(provider.sqlDeleteTOTPConfig)
|
||||
provider.sqlSelectTOTPConfigs = provider.db.Rebind(provider.sqlSelectTOTPConfigs)
|
||||
provider.sqlUpdateTOTPConfigSecret = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecret)
|
||||
provider.sqlUpdateTOTPConfigSecretByUsername = provider.db.Rebind(provider.sqlUpdateTOTPConfigSecretByUsername)
|
||||
provider.sqlSelectU2FDevice = provider.db.Rebind(provider.sqlSelectU2FDevice)
|
||||
provider.sqlSelectU2FDevices = provider.db.Rebind(provider.sqlSelectU2FDevices)
|
||||
provider.sqlUpdateU2FDevicePublicKey = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKey)
|
||||
provider.sqlUpdateU2FDevicePublicKeyByUsername = provider.db.Rebind(provider.sqlUpdateU2FDevicePublicKeyByUsername)
|
||||
provider.sqlSelectDuoDevice = provider.db.Rebind(provider.sqlSelectDuoDevice)
|
||||
provider.sqlDeleteDuoDevice = provider.db.Rebind(provider.sqlDeleteDuoDevice)
|
||||
provider.sqlInsertAuthenticationAttempt = provider.db.Rebind(provider.sqlInsertAuthenticationAttempt)
|
||||
|
|
|
@ -22,6 +22,26 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
|
|||
|
||||
key := sha256.Sum256([]byte(encryptionKey))
|
||||
|
||||
if err = p.schemaEncryptionChangeKeyTOTP(ctx, tx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.schemaEncryptionChangeKeyU2F(ctx, tx, key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rollback due to error: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaEncryptionChangeKeyTOTP(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||
var configs []models.TOTPConfiguration
|
||||
|
||||
for page := 0; true; page++ {
|
||||
|
@ -42,7 +62,7 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
|
|||
return fmt.Errorf("rollback due to error: %w", err)
|
||||
}
|
||||
|
||||
if err = p.UpdateTOTPConfigurationSecret(ctx, config); err != nil {
|
||||
if err = p.updateTOTPConfigurationSecret(ctx, config); err != nil {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||
}
|
||||
|
@ -56,15 +76,45 @@ func (p *SQLProvider) SchemaEncryptionChangeKey(ctx context.Context, encryptionK
|
|||
}
|
||||
}
|
||||
|
||||
if err = p.setNewEncryptionCheckValue(ctx, &key, tx); err != nil {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaEncryptionChangeKeyU2F(ctx context.Context, tx *sqlx.Tx, key [32]byte) (err error) {
|
||||
var devices []models.U2FDevice
|
||||
|
||||
for page := 0; true; page++ {
|
||||
if devices, err = p.LoadU2FDevices(ctx, 10, page); err != nil {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rollback due to error: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rollback due to error: %w", err)
|
||||
for _, device := range devices {
|
||||
if device.PublicKey, err = utils.Encrypt(device.PublicKey, &key); err != nil {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rollback due to error: %w", err)
|
||||
}
|
||||
|
||||
if err = p.updateU2FDevicePublicKey(ctx, device); err != nil {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
return fmt.Errorf("rollback error %v: rollback due to error: %w", rollbackErr, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("rollback due to error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(devices) != 10 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SchemaEncryptionCheckKey checks the encryption key configured is valid for the database.
|
||||
|
@ -85,49 +135,12 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
|||
}
|
||||
|
||||
if verbose {
|
||||
var (
|
||||
config models.TOTPConfiguration
|
||||
row int
|
||||
invalid int
|
||||
total int
|
||||
)
|
||||
|
||||
pageSize := 10
|
||||
|
||||
var rows *sqlx.Rows
|
||||
|
||||
for page := 0; true; page++ {
|
||||
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
|
||||
_ = rows.Close()
|
||||
|
||||
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
||||
}
|
||||
|
||||
row = 0
|
||||
|
||||
for rows.Next() {
|
||||
total++
|
||||
row++
|
||||
|
||||
if err = rows.StructScan(&config); err != nil {
|
||||
_ = rows.Close()
|
||||
return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
|
||||
}
|
||||
|
||||
if _, err = p.decrypt(config.Secret); err != nil {
|
||||
invalid++
|
||||
}
|
||||
}
|
||||
|
||||
_ = rows.Close()
|
||||
|
||||
if row < pageSize {
|
||||
break
|
||||
}
|
||||
if err = p.schemaEncryptionCheckTOTP(ctx); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
if invalid != 0 {
|
||||
errs = append(errs, fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total))
|
||||
if err = p.schemaEncryptionCheckU2F(ctx); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -148,6 +161,104 @@ func (p *SQLProvider) SchemaEncryptionCheckKey(ctx context.Context, verbose bool
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaEncryptionCheckTOTP(ctx context.Context) (err error) {
|
||||
var (
|
||||
config models.TOTPConfiguration
|
||||
row int
|
||||
invalid int
|
||||
total int
|
||||
)
|
||||
|
||||
pageSize := 10
|
||||
|
||||
var rows *sqlx.Rows
|
||||
|
||||
for page := 0; true; page++ {
|
||||
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectTOTPConfigs, pageSize, pageSize*page); err != nil {
|
||||
_ = rows.Close()
|
||||
|
||||
return fmt.Errorf("error selecting TOTP configurations: %w", err)
|
||||
}
|
||||
|
||||
row = 0
|
||||
|
||||
for rows.Next() {
|
||||
total++
|
||||
row++
|
||||
|
||||
if err = rows.StructScan(&config); err != nil {
|
||||
_ = rows.Close()
|
||||
return fmt.Errorf("error scanning TOTP configuration to struct: %w", err)
|
||||
}
|
||||
|
||||
if _, err = p.decrypt(config.Secret); err != nil {
|
||||
invalid++
|
||||
}
|
||||
}
|
||||
|
||||
_ = rows.Close()
|
||||
|
||||
if row < pageSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if invalid != 0 {
|
||||
return fmt.Errorf("%d of %d total TOTP secrets were invalid", invalid, total)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SQLProvider) schemaEncryptionCheckU2F(ctx context.Context) (err error) {
|
||||
var (
|
||||
device models.U2FDevice
|
||||
row int
|
||||
invalid int
|
||||
total int
|
||||
)
|
||||
|
||||
pageSize := 10
|
||||
|
||||
var rows *sqlx.Rows
|
||||
|
||||
for page := 0; true; page++ {
|
||||
if rows, err = p.db.QueryxContext(ctx, p.sqlSelectU2FDevices, pageSize, pageSize*page); err != nil {
|
||||
_ = rows.Close()
|
||||
|
||||
return fmt.Errorf("error selecting U2F devices: %w", err)
|
||||
}
|
||||
|
||||
row = 0
|
||||
|
||||
for rows.Next() {
|
||||
total++
|
||||
row++
|
||||
|
||||
if err = rows.StructScan(&device); err != nil {
|
||||
_ = rows.Close()
|
||||
return fmt.Errorf("error scanning U2F device to struct: %w", err)
|
||||
}
|
||||
|
||||
if _, err = p.decrypt(device.PublicKey); err != nil {
|
||||
invalid++
|
||||
}
|
||||
}
|
||||
|
||||
_ = rows.Close()
|
||||
|
||||
if row < pageSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if invalid != 0 {
|
||||
return fmt.Errorf("%d of %d total U2F devices were invalid", invalid, total)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p SQLProvider) encrypt(clearText []byte) (cipherText []byte, err error) {
|
||||
return utils.Encrypt(clearText, &p.key)
|
||||
}
|
||||
|
|
|
@ -60,16 +60,16 @@ const (
|
|||
SELECT EXISTS (
|
||||
SELECT id
|
||||
FROM %s
|
||||
WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND used IS NULL
|
||||
WHERE jti = ? AND exp > CURRENT_TIMESTAMP AND consumed IS NULL
|
||||
);`
|
||||
|
||||
queryFmtInsertIdentityVerification = `
|
||||
INSERT INTO %s (jti, iat, exp, username, action)
|
||||
VALUES (?, ?, ?, ?, ?);`
|
||||
INSERT INTO %s (jti, iat, issued_ip, exp, username, action)
|
||||
VALUES (?, ?, ?, ?, ?, ?);`
|
||||
|
||||
queryFmtDeleteIdentityVerification = `
|
||||
queryFmtConsumeIdentityVerification = `
|
||||
UPDATE %s
|
||||
SET used = CURRENT_TIMESTAMP
|
||||
SET consumed = CURRENT_TIMESTAMP, consumed_ip = ?
|
||||
WHERE jti = ?;`
|
||||
)
|
||||
|
||||
|
@ -114,10 +114,26 @@ const (
|
|||
|
||||
const (
|
||||
queryFmtSelectU2FDevice = `
|
||||
SELECT key_handle, public_key
|
||||
SELECT id, username, key_handle, public_key
|
||||
FROM %s
|
||||
WHERE username = ?;`
|
||||
|
||||
queryFmtSelectU2FDevices = `
|
||||
SELECT id, username, key_handle, public_key
|
||||
FROM %s
|
||||
LIMIT ?
|
||||
OFFSET ?;`
|
||||
|
||||
queryFmtUpdateU2FDevicePublicKey = `
|
||||
UPDATE %s
|
||||
SET public_key = ?
|
||||
WHERE id = ?;`
|
||||
|
||||
queryFmtUpdateUpdateU2FDevicePublicKeyByUsername = `
|
||||
UPDATE %s
|
||||
SET public_key = ?
|
||||
WHERE username = ?;`
|
||||
|
||||
queryFmtUpsertU2FDevice = `
|
||||
REPLACE INTO %s (username, description, key_handle, public_key)
|
||||
VALUES (?, ?, ?, ?);`
|
||||
|
|
|
@ -2,6 +2,7 @@ package storage
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
@ -57,9 +58,13 @@ func (p *SQLProvider) SchemaVersion(ctx context.Context) (version int, err error
|
|||
return migration.After, nil
|
||||
}
|
||||
|
||||
if utils.IsStringInSlice(tableUserPreferences, tables) && utils.IsStringInSlice(tablePre1TOTPSecrets, tables) &&
|
||||
utils.IsStringInSlice(tableU2FDevices, tables) && utils.IsStringInSlice(tableAuthenticationLogs, tables) &&
|
||||
utils.IsStringInSlice(tablePre1IdentityVerificationTokens, tables) && !utils.IsStringInSlice(tableMigrations, tables) {
|
||||
var tablesV1 = []string{tableDuoDevices, tableEncryption, tableIdentityVerification, tableMigrations, tableTOTPConfigurations}
|
||||
|
||||
if utils.IsStringSliceContainsAll(tablesPre1, tables) {
|
||||
if utils.IsStringSliceContainsAny(tablesV1, tables) {
|
||||
return -2, errors.New("pre1 schema contains v1 tables it shouldn't contain")
|
||||
}
|
||||
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -267,7 +267,12 @@ func (p *SQLProvider) schemaMigratePre1To1U2F(ctx context.Context) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: publicKey})
|
||||
encryptedPublicKey, err := p.encrypt(publicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
devices = append(devices, models.U2FDevice{Username: username, KeyHandle: keyHandle, PublicKey: encryptedPublicKey})
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
|
@ -446,6 +451,11 @@ func (p *SQLProvider) schemaMigrate1ToPre1U2F(ctx context.Context) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
device.PublicKey, err = p.decrypt(device.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
devices = append(devices, device)
|
||||
}
|
||||
|
||||
|
|
|
@ -91,6 +91,17 @@ func IsStringSliceContainsAll(needles []string, haystack []string) (inSlice bool
|
|||
return true
|
||||
}
|
||||
|
||||
// IsStringSliceContainsAny checks if the haystack contains any of the strings in the needles.
|
||||
func IsStringSliceContainsAny(needles []string, haystack []string) (inSlice bool) {
|
||||
for _, n := range needles {
|
||||
if IsStringInSlice(n, haystack) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SliceString splits a string s into an array with each item being a max of int d
|
||||
// d = denominator, n = numerator, q = quotient, r = remainder.
|
||||
func SliceString(s string, d int) (array []string) {
|
||||
|
|
|
@ -162,3 +162,12 @@ func TestIsStringSliceContainsAll(t *testing.T) {
|
|||
assert.True(t, IsStringSliceContainsAll(needles, haystackOne))
|
||||
assert.False(t, IsStringSliceContainsAll(needles, haystackTwo))
|
||||
}
|
||||
|
||||
func TestIsStringSliceContainsAny(t *testing.T) {
|
||||
needles := []string{"abc", "123", "xyz"}
|
||||
haystackOne := []string{"tvu", "456", "hij"}
|
||||
haystackTwo := []string{"tvu", "123", "456", "xyz"}
|
||||
|
||||
assert.False(t, IsStringSliceContainsAny(needles, haystackOne))
|
||||
assert.True(t, IsStringSliceContainsAny(needles, haystackTwo))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user