mirror of
https://github.com/getsops/sops.git
synced 2026-02-05 12:45:21 +01:00
kms: AWS SDK V2, allow creds config, add tests
This updates the AWS SDK for Go to V2, adds extensive test coverage based on a mocking server, and a general tidying of bits of code. The improvements are based on a fork of the key source in the Flux project's kustomize-controller, built due to SOPS' limitation around credential management without relying on runtime environment variables. - AWS SDK has been updated to V2. There are still bits in `publish/` which would need updating to drop the dependency on V1. - It introduces a `CredentialsProvider` type which holds an `aws.CredentialsProvider`, and can be applied to the `MasterKey`. When applied, the provider is used in the AWS client configuration instead of relying on the SDK default (environmental) values. This is most useful when working with SOPS as an SDK, in combination with e.g. a local key service server implementation. - Extensive test coverage. STS session implementation details are not tested due to mocking complexities, but the wiring is. The forked version of this has compatibility tests to ensure it works with current SOPS: -8b7e7ecb1a/internal/sops/awskms/keysource_test.go (L134)-8b7e7ecb1a/internal/sops/awskms/keysource_test.go (L200)Co-authored-by: Sanskar Jaiswal <sanskar.jaiswal@weave.works> Signed-off-by: Hidde Beydals <hello@hidde.co>
This commit is contained in:
12
go.mod
12
go.mod
@@ -10,6 +10,11 @@ require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.5.1
|
||||
github.com/ProtonMail/go-crypto v0.0.0-20220407094043-a94812496cf5
|
||||
github.com/aws/aws-sdk-go v1.43.43
|
||||
github.com/aws/aws-sdk-go-v2 v1.16.4
|
||||
github.com/aws/aws-sdk-go-v2/config v1.15.9
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.12.4
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.17.2
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.16.6
|
||||
github.com/blang/semver v3.5.1+incompatible
|
||||
github.com/fatih/color v1.13.0
|
||||
github.com/golang/protobuf v1.5.2
|
||||
@@ -49,6 +54,13 @@ require (
|
||||
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
|
||||
github.com/armon/go-metrics v0.3.10 // indirect
|
||||
github.com/armon/go-radix v1.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.3.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.11.7 // indirect
|
||||
github.com/aws/smithy-go v1.11.2 // indirect
|
||||
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
|
||||
github.com/cenkalti/backoff/v3 v3.2.2 // indirect
|
||||
github.com/containerd/continuity v0.2.2 // indirect
|
||||
|
||||
24
go.sum
24
go.sum
@@ -96,6 +96,30 @@ github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
|
||||
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
|
||||
github.com/aws/aws-sdk-go v1.43.43 h1:1L06qzQvl4aC3Skfh5rV7xVhGHjIZoHcqy16NoyQ1o4=
|
||||
github.com/aws/aws-sdk-go v1.43.43/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo=
|
||||
github.com/aws/aws-sdk-go-v2 v1.16.4 h1:swQTEQUyJF/UkEA94/Ga55miiKFoXmm/Zd67XHgmjSg=
|
||||
github.com/aws/aws-sdk-go-v2 v1.16.4/go.mod h1:ytwTPBG6fXTZLxxeeCCWj2/EMYp/xDUgX+OET6TLNNU=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.15.9 h1:TK5yNEnFDQ9iaO04gJS/3Y+eW8BioQiCUafW75/Wc3Q=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.15.9/go.mod h1:rv/l/TbZo67kp99v/3Kb0qV6Fm1KEtKyruEV2GvVfgs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.12.4 h1:xggwS+qxCukXRVXJBJWQJGyUsvuxGC8+J1kKzv2cxuw=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.12.4/go.mod h1:7g+GGSp7xtR823o1jedxKmqRZGqLdoHQfI4eFasKKxs=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.5 h1:YPxclBeE07HsLQE8vtjC8T2emcTjM9nzqsnDi2fv5UM=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.5/go.mod h1:WAPnuhG5IQ/i6DETFl5NmX3kKqCzw7aau9NHAGcm4QE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.11 h1:gsqHplNh1DaQunEKZISK56wlpbCg0yKxNVvGWCFuF1k=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.11/go.mod h1:tmUB6jakq5DFNcXsXOA/ZQ7/C8VnSKYkx58OI7Fh79g=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.5 h1:PLFj+M2PgIDHG//hw3T0O0KLI4itVtAjtxrZx4AHPLg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.5/go.mod h1:fV1AaS2gFc1tM0RCb015FJ0pvWVUfJZANzjwoO4YakM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.3.12 h1:j0VqrjtgsY1Bx27tD0ysay36/K4kFMWRp9K3ieO9nLU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.3.12/go.mod h1:00c7+ALdPh4YeEUPXJzyU0Yy01nPGOq2+9rUaz05z9g=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.5 h1:gRW1ZisKc93EWEORNJRvy/ZydF3o6xLSveJHdi1Oa0U=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.5/go.mod h1:ZbkttHXaVn3bBo/wpJbQGiiIWR90eTBUVBrEHUEQlho=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.17.2 h1:g5sAKPf2OyQf6Qk/HmisWJvAbp3+vjfX1d2wLPUXo1Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/kms v1.17.2/go.mod h1:O99LMSMb/hDB0sQ3OI3SV1rMzwVH/g4608bps5k5dr8=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.11.7 h1:suAGD+RyiHWPPihZzY+jw4mCZlOFWgmdjb2AeTenz7c=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.11.7/go.mod h1:TFVe6Rr2joVLsYQ1ABACXgOC6lXip/qpX2x5jWg/A9w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.16.6 h1:aYToU0/iazkMY67/BYLt3r6/LT/mUtarLAF5mGof1Kg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.16.6/go.mod h1:rP1rEOKAGZoXp4iGDxSXFvODAtXpm34Egf0lL0eshaQ=
|
||||
github.com/aws/smithy-go v1.11.2 h1:eG/N+CcUMAvsdffgMvjMKwfyDzIkjM6pfxMJ8Mzc6mE=
|
||||
github.com/aws/smithy-go v1.11.2/go.mod h1:3xHYmszWVx2c0kIwQeEVf9uSm4fYZt67FBJnwub1bgM=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
|
||||
475
kms/keysource.go
475
kms/keysource.go
@@ -1,10 +1,12 @@
|
||||
/*
|
||||
Package kms contains an implementation of the go.mozilla.org/sops/v3.MasterKey interface that encrypts and decrypts the
|
||||
data key using AWS KMS with the AWS Go SDK.
|
||||
Package kms contains an implementation of the go.mozilla.org/sops/v3.MasterKey
|
||||
interface that encrypts and decrypts the data key using AWS KMS with the SDK
|
||||
for Go V2.
|
||||
*/
|
||||
package kms //import "go.mozilla.org/sops/v3/kms"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -12,116 +14,71 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mozilla.org/sops/v3/logging"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
|
||||
"github.com/aws/aws-sdk-go/service/sts"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kms"
|
||||
"github.com/aws/aws-sdk-go-v2/service/sts"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.mozilla.org/sops/v3/logging"
|
||||
)
|
||||
|
||||
var log *logrus.Logger
|
||||
const (
|
||||
// arnRegex matches an AWS ARN, for example:
|
||||
// "arn:aws:kms:us-west-2:107501996527:key/612d5f0p-p1l3-45e6-aca6-a5b005693a48".
|
||||
arnRegex = `^arn:aws[\w-]*:kms:(.+):[0-9]+:(key|alias)/.+$`
|
||||
// stsSessionRegex matches an AWS STS session name, for example:
|
||||
// "john_s", "sops@42WQm042".
|
||||
stsSessionRegex = "[^a-zA-Z0-9=,.@-_]+"
|
||||
// roleSessionNameLengthLimit is the AWS role session name length limit.
|
||||
roleSessionNameLengthLimit = 64
|
||||
// kmsTTL is the duration after which a MasterKey requires rotation.
|
||||
kmsTTL = time.Hour * 24 * 30 * 6
|
||||
)
|
||||
|
||||
var (
|
||||
// log is the global logger for any AWS KMS MasterKey.
|
||||
log *logrus.Logger
|
||||
// osHostname returns the hostname as reported by the kernel.
|
||||
osHostname = os.Hostname
|
||||
)
|
||||
|
||||
func init() {
|
||||
log = logging.NewLogger("AWSKMS")
|
||||
}
|
||||
|
||||
// this needs to be a global var for unit tests to work (mockKMS redefines
|
||||
// it in keysource_test.go)
|
||||
var kmsSvc kmsiface.KMSAPI
|
||||
var isMocked bool
|
||||
|
||||
// MasterKey is a AWS KMS key used to encrypt and decrypt sops' data key.
|
||||
// MasterKey is an AWS KMS key used to encrypt and decrypt SOPS' data key using
|
||||
// AWS SDK for Go V2.
|
||||
type MasterKey struct {
|
||||
Arn string
|
||||
Role string
|
||||
EncryptedKey string
|
||||
CreationDate time.Time
|
||||
// Arn associated with the AWS KMS key.
|
||||
Arn string
|
||||
// Role ARN used to assume a role through AWS STS.
|
||||
Role string
|
||||
// EncryptedKey stores the data key in it's encrypted form.
|
||||
EncryptedKey string
|
||||
// CreationDate is when this MasterKey was created.
|
||||
CreationDate time.Time
|
||||
// EncryptionContext provides additional context about the data key.
|
||||
// Ref: https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#encrypt_context
|
||||
EncryptionContext map[string]*string
|
||||
AwsProfile string
|
||||
// AwsProfile is the profile to use for loading configuration and credentials.
|
||||
// Ref: https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/#specifying-profiles
|
||||
AwsProfile string
|
||||
|
||||
// credentialsProvider is used to configure the AWS client config with
|
||||
// credentials. It can be injected by a (local) keyservice.KeyServiceServer
|
||||
// using CredentialsProvider.ApplyToMasterKey. If nil, the default client is used
|
||||
// which utilizes runtime environmental values.
|
||||
credentialsProvider aws.CredentialsProvider
|
||||
// epResolver can be used to override the endpoint the AWS client resolves
|
||||
// to by default. This is mostly used for testing purposes as it can not be
|
||||
// injected using e.g. an environment variable. The field is not publicly
|
||||
// exposed, nor configurable.
|
||||
epResolver aws.EndpointResolverWithOptions
|
||||
}
|
||||
|
||||
// EncryptedDataKey returns the encrypted data key this master key holds
|
||||
func (key *MasterKey) EncryptedDataKey() []byte {
|
||||
return []byte(key.EncryptedKey)
|
||||
}
|
||||
|
||||
// SetEncryptedDataKey sets the encrypted data key for this master key
|
||||
func (key *MasterKey) SetEncryptedDataKey(enc []byte) {
|
||||
key.EncryptedKey = string(enc)
|
||||
}
|
||||
|
||||
// Encrypt takes a sops data key, encrypts it with KMS and stores the result in the EncryptedKey field
|
||||
func (key *MasterKey) Encrypt(dataKey []byte) error {
|
||||
// isMocked is set by unit test to indicate that the KMS service
|
||||
// has already been initialized. it's ugly, but it works.
|
||||
if kmsSvc == nil || !isMocked {
|
||||
sess, err := key.createSession()
|
||||
if err != nil {
|
||||
log.WithField("arn", key.Arn).Info("Encryption failed")
|
||||
return fmt.Errorf("Failed to create session: %w", err)
|
||||
}
|
||||
kmsSvc = kms.New(sess)
|
||||
}
|
||||
out, err := kmsSvc.Encrypt(&kms.EncryptInput{Plaintext: dataKey, KeyId: &key.Arn, EncryptionContext: key.EncryptionContext})
|
||||
if err != nil {
|
||||
log.WithField("arn", key.Arn).Info("Encryption failed")
|
||||
return fmt.Errorf("Failed to call KMS encryption service: %w", err)
|
||||
}
|
||||
key.EncryptedKey = base64.StdEncoding.EncodeToString(out.CiphertextBlob)
|
||||
log.WithField("arn", key.Arn).Info("Encryption succeeded")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncryptIfNeeded encrypts the provided sops' data key and encrypts it if it hasn't been encrypted yet
|
||||
func (key *MasterKey) EncryptIfNeeded(dataKey []byte) error {
|
||||
if key.EncryptedKey == "" {
|
||||
return key.Encrypt(dataKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the EncryptedKey field with AWS KMS and returns the result.
|
||||
func (key *MasterKey) Decrypt() ([]byte, error) {
|
||||
k, err := base64.StdEncoding.DecodeString(key.EncryptedKey)
|
||||
if err != nil {
|
||||
log.WithField("arn", key.Arn).Info("Decryption failed")
|
||||
return nil, fmt.Errorf("Error base64-decoding encrypted data key: %s", err)
|
||||
}
|
||||
// isMocked is set by unit test to indicate that the KMS service
|
||||
// has already been initialized. it's ugly, but it works.
|
||||
if kmsSvc == nil || !isMocked {
|
||||
sess, err := key.createSession()
|
||||
if err != nil {
|
||||
log.WithField("arn", key.Arn).Info("Decryption failed")
|
||||
return nil, fmt.Errorf("Error creating AWS session: %w", err)
|
||||
}
|
||||
kmsSvc = kms.New(sess)
|
||||
}
|
||||
decrypted, err := kmsSvc.Decrypt(&kms.DecryptInput{CiphertextBlob: k, EncryptionContext: key.EncryptionContext})
|
||||
if err != nil {
|
||||
log.WithField("arn", key.Arn).Info("Decryption failed")
|
||||
return nil, fmt.Errorf("Error decrypting key: %w", err)
|
||||
}
|
||||
log.WithField("arn", key.Arn).Info("Decryption succeeded")
|
||||
return decrypted.Plaintext, nil
|
||||
}
|
||||
|
||||
// NeedsRotation returns whether the data key needs to be rotated or not.
|
||||
func (key *MasterKey) NeedsRotation() bool {
|
||||
return time.Since(key.CreationDate) > (time.Hour * 24 * 30 * 6)
|
||||
}
|
||||
|
||||
// ToString converts the key to a string representation
|
||||
func (key *MasterKey) ToString() string {
|
||||
return key.Arn
|
||||
}
|
||||
|
||||
// NewMasterKey creates a new MasterKey from an ARN, role and context, setting the creation date to the current date
|
||||
// NewMasterKey creates a new MasterKey from an ARN, role and context, setting
|
||||
// the creation date to the current date.
|
||||
func NewMasterKey(arn string, role string, context map[string]*string) *MasterKey {
|
||||
return &MasterKey{
|
||||
Arn: arn,
|
||||
@@ -131,24 +88,26 @@ func NewMasterKey(arn string, role string, context map[string]*string) *MasterKe
|
||||
}
|
||||
}
|
||||
|
||||
// NewMasterKeyFromArn takes an ARN string and returns a new MasterKey for that ARN
|
||||
// NewMasterKeyFromArn takes an ARN string and returns a new MasterKey for that
|
||||
// ARN.
|
||||
func NewMasterKeyFromArn(arn string, context map[string]*string, awsProfile string) *MasterKey {
|
||||
k := &MasterKey{}
|
||||
key := &MasterKey{}
|
||||
arn = strings.Replace(arn, " ", "", -1)
|
||||
key.Arn = arn
|
||||
roleIndex := strings.Index(arn, "+arn:aws:iam::")
|
||||
if roleIndex > 0 {
|
||||
k.Arn = arn[:roleIndex]
|
||||
k.Role = arn[roleIndex+1:]
|
||||
} else {
|
||||
k.Arn = arn
|
||||
// Overwrite ARN
|
||||
key.Arn = arn[:roleIndex]
|
||||
key.Role = arn[roleIndex+1:]
|
||||
}
|
||||
k.EncryptionContext = context
|
||||
k.CreationDate = time.Now().UTC()
|
||||
k.AwsProfile = awsProfile
|
||||
return k
|
||||
key.EncryptionContext = context
|
||||
key.CreationDate = time.Now().UTC()
|
||||
key.AwsProfile = awsProfile
|
||||
return key
|
||||
}
|
||||
|
||||
// MasterKeysFromArnString takes a comma separated list of AWS KMS ARNs and returns a slice of new MasterKeys for those ARNs
|
||||
// MasterKeysFromArnString takes a comma separated list of AWS KMS ARNs, and
|
||||
// returns a slice of new MasterKeys for those ARNs.
|
||||
func MasterKeysFromArnString(arn string, context map[string]*string, awsProfile string) []*MasterKey {
|
||||
var keys []*MasterKey
|
||||
if arn == "" {
|
||||
@@ -160,88 +119,11 @@ func MasterKeysFromArnString(arn string, context map[string]*string, awsProfile
|
||||
return keys
|
||||
}
|
||||
|
||||
func (key MasterKey) createStsSession(config aws.Config, sess *session.Session) (*session.Session, error) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stsRoleSessionNameRe, err := regexp.Compile("[^a-zA-Z0-9=,.@-]+")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to compile STS role session name regex: %w", err)
|
||||
}
|
||||
sanitizedHostname := stsRoleSessionNameRe.ReplaceAllString(hostname, "")
|
||||
stsService := sts.New(sess)
|
||||
name := "sops@" + sanitizedHostname
|
||||
|
||||
// Make sure the name is no longer than 64 characters (role session name length limit from AWS)
|
||||
roleSessionNameLengthLimit := 64
|
||||
if len(name) >= roleSessionNameLengthLimit {
|
||||
name = name[:roleSessionNameLengthLimit]
|
||||
}
|
||||
|
||||
out, err := stsService.AssumeRole(&sts.AssumeRoleInput{
|
||||
RoleArn: &key.Role, RoleSessionName: &name})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to assume role %q: %w", key.Role, err)
|
||||
}
|
||||
config.Credentials = credentials.NewStaticCredentials(*out.Credentials.AccessKeyId,
|
||||
*out.Credentials.SecretAccessKey, *out.Credentials.SessionToken)
|
||||
sess, err = session.NewSession(&config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to create new aws session: %w", err)
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (key MasterKey) createSession() (*session.Session, error) {
|
||||
re := regexp.MustCompile(`^arn:aws[\w-]*:kms:(.+):[0-9]+:(key|alias)/.+$`)
|
||||
matches := re.FindStringSubmatch(key.Arn)
|
||||
if matches == nil {
|
||||
return nil, fmt.Errorf("No valid ARN found in %q", key.Arn)
|
||||
}
|
||||
|
||||
config := aws.Config{Region: aws.String(matches[1])}
|
||||
|
||||
opts := session.Options{
|
||||
Profile: key.AwsProfile,
|
||||
Config: config,
|
||||
AssumeRoleTokenProvider: stscreds.StdinTokenProvider,
|
||||
SharedConfigState: session.SharedConfigEnable,
|
||||
}
|
||||
sess, err := session.NewSessionWithOptions(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if key.Role != "" {
|
||||
return key.createStsSession(config, sess)
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
// ToMap converts the MasterKey to a map for serialization purposes
|
||||
func (key MasterKey) ToMap() map[string]interface{} {
|
||||
out := make(map[string]interface{})
|
||||
out["arn"] = key.Arn
|
||||
if key.Role != "" {
|
||||
out["role"] = key.Role
|
||||
}
|
||||
out["created_at"] = key.CreationDate.UTC().Format(time.RFC3339)
|
||||
out["enc"] = key.EncryptedKey
|
||||
if key.EncryptionContext != nil {
|
||||
outcontext := make(map[string]string)
|
||||
for k, v := range key.EncryptionContext {
|
||||
outcontext[k] = *v
|
||||
}
|
||||
out["context"] = outcontext
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ParseKMSContext takes either a KMS context map or a comma-separated list of KMS context key:value pairs and returns a map
|
||||
// ParseKMSContext takes either a KMS context map or a comma-separated list of
|
||||
// KMS context key:value pairs, and returns a map.
|
||||
func ParseKMSContext(in interface{}) map[string]*string {
|
||||
nonStringValueWarning := "Encryption context contains a non-string value, context will not be used"
|
||||
const nonStringValueWarning = "Encryption context contains a non-string value, context will not be used"
|
||||
out := make(map[string]*string)
|
||||
|
||||
switch in := in.(type) {
|
||||
case map[string]interface{}:
|
||||
if len(in) == 0 {
|
||||
@@ -287,3 +169,214 @@ func ParseKMSContext(in interface{}) map[string]*string {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// CredentialsProvider is a wrapper around aws.CredentialsProvider used for
|
||||
// authentication towards AWS KMS.
|
||||
type CredentialsProvider struct {
|
||||
provider aws.CredentialsProvider
|
||||
}
|
||||
|
||||
// NewCredentialsProvider returns a CredentialsProvider object with the provided
|
||||
// aws.CredentialsProvider.
|
||||
func NewCredentialsProvider(cp aws.CredentialsProvider) *CredentialsProvider {
|
||||
return &CredentialsProvider{
|
||||
provider: cp,
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyToMasterKey configures the credentials on the provided key.
|
||||
func (c CredentialsProvider) ApplyToMasterKey(key *MasterKey) {
|
||||
key.credentialsProvider = c.provider
|
||||
}
|
||||
|
||||
// Encrypt takes a SOPS data key, encrypts it with KMS and stores the result
|
||||
// in the EncryptedKey field.
|
||||
func (key *MasterKey) Encrypt(dataKey []byte) error {
|
||||
cfg, err := key.createKMSConfig()
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("arn", key.Arn).Error("Encryption failed")
|
||||
return err
|
||||
}
|
||||
client := kms.NewFromConfig(*cfg)
|
||||
input := &kms.EncryptInput{
|
||||
KeyId: &key.Arn,
|
||||
Plaintext: dataKey,
|
||||
EncryptionContext: stringPointerToStringMap(key.EncryptionContext),
|
||||
}
|
||||
out, err := client.Encrypt(context.TODO(), input)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("arn", key.Arn).Error("Encryption failed")
|
||||
return fmt.Errorf("failed to encrypt sops data key with AWS KMS: %w", err)
|
||||
}
|
||||
key.EncryptedKey = base64.StdEncoding.EncodeToString(out.CiphertextBlob)
|
||||
log.WithField("arn", key.Arn).Info("Encryption succeeded")
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncryptIfNeeded encrypts the provided SOPS data key, if it has not been
|
||||
// encrypted yet.
|
||||
func (key *MasterKey) EncryptIfNeeded(dataKey []byte) error {
|
||||
if key.EncryptedKey == "" {
|
||||
return key.Encrypt(dataKey)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncryptedDataKey returns the encrypted data key this master key holds.
|
||||
func (key *MasterKey) EncryptedDataKey() []byte {
|
||||
return []byte(key.EncryptedKey)
|
||||
}
|
||||
|
||||
// SetEncryptedDataKey sets the encrypted data key for this master key.
|
||||
func (key *MasterKey) SetEncryptedDataKey(enc []byte) {
|
||||
key.EncryptedKey = string(enc)
|
||||
}
|
||||
|
||||
// Decrypt decrypts the EncryptedKey with a newly created AWS KMS config, and
|
||||
// returns the result.
|
||||
func (key *MasterKey) Decrypt() ([]byte, error) {
|
||||
k, err := base64.StdEncoding.DecodeString(key.EncryptedKey)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("arn", key.Arn).Error("Decryption failed")
|
||||
return nil, fmt.Errorf("error base64-decoding encrypted data key: %s", err)
|
||||
}
|
||||
cfg, err := key.createKMSConfig()
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("arn", key.Arn).Error("Decryption failed")
|
||||
return nil, err
|
||||
}
|
||||
client := kms.NewFromConfig(*cfg)
|
||||
input := &kms.DecryptInput{
|
||||
KeyId: &key.Arn,
|
||||
CiphertextBlob: k,
|
||||
EncryptionContext: stringPointerToStringMap(key.EncryptionContext),
|
||||
}
|
||||
decrypted, err := client.Decrypt(context.TODO(), input)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("arn", key.Arn).Error("Decryption failed")
|
||||
return nil, fmt.Errorf("failed to decrypt sops data key with AWS KMS: %w", err)
|
||||
}
|
||||
log.WithField("arn", key.Arn).Info("Decryption succeeded")
|
||||
return decrypted.Plaintext, nil
|
||||
}
|
||||
|
||||
// NeedsRotation returns whether the data key needs to be rotated or not.
|
||||
func (key *MasterKey) NeedsRotation() bool {
|
||||
return time.Since(key.CreationDate) > kmsTTL
|
||||
}
|
||||
|
||||
// ToString converts the key to a string representation.
|
||||
func (key *MasterKey) ToString() string {
|
||||
return key.Arn
|
||||
}
|
||||
|
||||
// ToMap converts the MasterKey to a map for serialization purposes.
|
||||
func (key MasterKey) ToMap() map[string]interface{} {
|
||||
out := make(map[string]interface{})
|
||||
out["arn"] = key.Arn
|
||||
if key.Role != "" {
|
||||
out["role"] = key.Role
|
||||
}
|
||||
out["created_at"] = key.CreationDate.UTC().Format(time.RFC3339)
|
||||
out["enc"] = key.EncryptedKey
|
||||
if key.EncryptionContext != nil {
|
||||
outcontext := make(map[string]string)
|
||||
for k, v := range key.EncryptionContext {
|
||||
outcontext[k] = *v
|
||||
}
|
||||
out["context"] = outcontext
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// createKMSConfig returns an AWS config with the credentialsProvider of the
|
||||
// MasterKey, or the default configuration sources.
|
||||
func (key MasterKey) createKMSConfig() (*aws.Config, error) {
|
||||
re := regexp.MustCompile(arnRegex)
|
||||
matches := re.FindStringSubmatch(key.Arn)
|
||||
if matches == nil {
|
||||
return nil, fmt.Errorf("no valid ARN found in '%s'", key.Arn)
|
||||
}
|
||||
region := matches[1]
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(context.TODO(), func(lo *config.LoadOptions) error {
|
||||
// Use the credentialsProvider if present, otherwise default to reading credentials
|
||||
// from the environment.
|
||||
if key.credentialsProvider != nil {
|
||||
lo.Credentials = key.credentialsProvider
|
||||
}
|
||||
if key.AwsProfile != "" {
|
||||
lo.SharedConfigProfile = key.AwsProfile
|
||||
}
|
||||
lo.Region = region
|
||||
|
||||
// Set the epResolver, if present. Used ONLY for tests.
|
||||
if key.epResolver != nil {
|
||||
lo.EndpointResolverWithOptions = key.epResolver
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not load AWS config: %w", err)
|
||||
}
|
||||
|
||||
if key.Role != "" {
|
||||
return key.createSTSConfig(&cfg)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// createSTSConfig uses AWS STS to assume a role and returns a config
|
||||
// configured with that role's credentials. It returns an error if
|
||||
// it fails to construct a session name, or assume the role.
|
||||
func (key MasterKey) createSTSConfig(config *aws.Config) (*aws.Config, error) {
|
||||
name, err := stsSessionName()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
input := &sts.AssumeRoleInput{
|
||||
RoleArn: &key.Role,
|
||||
RoleSessionName: &name,
|
||||
}
|
||||
|
||||
client := sts.NewFromConfig(*config)
|
||||
out, err := client.AssumeRole(context.TODO(), input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to assume role '%s': %w", key.Role, err)
|
||||
}
|
||||
|
||||
config.Credentials = credentials.NewStaticCredentialsProvider(*out.Credentials.AccessKeyId,
|
||||
*out.Credentials.SecretAccessKey, *out.Credentials.SessionToken,
|
||||
)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// stsSessionName returns the name for the STS session in the format of
|
||||
// `sops@<hostname>`. It sanitizes the hostname with stsSessionRegex, and
|
||||
// truncates to roleSessionNameLengthLimit when it exceeds the limit.
|
||||
func stsSessionName() (string, error) {
|
||||
hostname, err := osHostname()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to construct STS session name: %w", err)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(stsSessionRegex)
|
||||
sanitizedHostname := re.ReplaceAllString(hostname, "")
|
||||
|
||||
name := "sops@" + sanitizedHostname
|
||||
if len(name) >= roleSessionNameLengthLimit {
|
||||
name = name[:roleSessionNameLengthLimit]
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func stringPointerToStringMap(in map[string]*string) map[string]string {
|
||||
var out = make(map[string]string)
|
||||
for k, v := range in {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
out[k] = *v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -1,75 +1,172 @@
|
||||
package kms
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
logger "log"
|
||||
"os"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/kms"
|
||||
"github.com/ory/dockertest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"go.mozilla.org/sops/v3/kms/mocks"
|
||||
)
|
||||
|
||||
func TestKMS(t *testing.T) {
|
||||
mockKMS := &mocks.KMSAPI{}
|
||||
defer mockKMS.AssertExpectations(t)
|
||||
kmsSvc = mockKMS
|
||||
isMocked = true
|
||||
encryptOutput := &kms.EncryptOutput{}
|
||||
decryptOutput := &kms.DecryptOutput{}
|
||||
mockKMS.On("Encrypt", mock.AnythingOfType("*kms.EncryptInput")).Return(encryptOutput, nil).Run(func(args mock.Arguments) {
|
||||
encryptOutput.CiphertextBlob = args.Get(0).(*kms.EncryptInput).Plaintext
|
||||
var (
|
||||
// testKMSServerURL is the URL of the AWS KMS server running in Docker.
|
||||
// It is loaded by TestMain.
|
||||
testKMSServerURL string
|
||||
// testKMSARN is the ARN on the test AWS KMS server. It is loaded
|
||||
// by TestMain.
|
||||
testKMSARN string
|
||||
)
|
||||
|
||||
const (
|
||||
// dummyARN is a dummy AWS ARN which passes validation.
|
||||
dummyARN = "arn:aws:kms:us-west-2:107501996527:key/612d5f0p-p1l3-45e6-aca6-a5b005693a48"
|
||||
// testLocalKMSImage is a container image repository reference to a mock
|
||||
// version of AWS' Key Management Service.
|
||||
// Ref: https://github.com/nsmithuk/local-kms
|
||||
testLocalKMSImage = "docker.io/nsmithuk/local-kms"
|
||||
// testLocalKMSImage is the container image tag to use.
|
||||
testLocalKMSTag = "3.11.1"
|
||||
)
|
||||
|
||||
// TestMain initializes an AWS KMS server using Docker, writes the HTTP address
|
||||
// to testKMSServerURL, tries to generate a key for encryption-decryption using a
|
||||
// backoff retry approach, and then sets testKMSARN to the ID of the generated key.
|
||||
// It continues to run all the tests, which can make use of the various `test*`
|
||||
// variables.
|
||||
func TestMain(m *testing.M) {
|
||||
// Uses a sensible default on Windows (TCP/HTTP) and Linux/MacOS (socket)
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
logger.Fatalf("could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
// Pull the image, create a container based on it, and run it
|
||||
// resource, err := pool.Run("nsmithuk/local-kms", testLocalKMSVersion, []string{})
|
||||
resource, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: testLocalKMSImage,
|
||||
Tag: testLocalKMSTag,
|
||||
ExposedPorts: []string{"8080"},
|
||||
})
|
||||
mockKMS.On("Decrypt", mock.AnythingOfType("*kms.DecryptInput")).Return(decryptOutput, nil).Run(func(args mock.Arguments) {
|
||||
decryptOutput.Plaintext = args.Get(0).(*kms.DecryptInput).CiphertextBlob
|
||||
})
|
||||
k := MasterKey{Arn: "arn:aws:kms:us-east-1:927034868273:key/e9fc75db-05e9-44c1-9c35-633922bac347", Role: "", EncryptedKey: ""}
|
||||
f := func(x []byte) bool {
|
||||
err := k.Encrypt(x)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
logger.Fatalf("could not start resource: %s", err)
|
||||
}
|
||||
|
||||
purgeResource := func() {
|
||||
if err := pool.Purge(resource); err != nil {
|
||||
logger.Printf("could not purge resource: %s", err)
|
||||
}
|
||||
v, err := k.Decrypt()
|
||||
}
|
||||
|
||||
testKMSServerURL = fmt.Sprintf("http://127.0.0.1:%v", resource.GetPort("8080/tcp"))
|
||||
masterKey := createTestMasterKey(dummyARN)
|
||||
|
||||
kmsClient, err := createTestKMSClient(masterKey)
|
||||
if err != nil {
|
||||
purgeResource()
|
||||
logger.Fatalf("could not create session: %s", err)
|
||||
}
|
||||
|
||||
var key *kms.CreateKeyOutput
|
||||
if err := pool.Retry(func() error {
|
||||
key, err = kmsClient.CreateKey(context.TODO(), &kms.CreateKeyInput{})
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
return bytes.Equal(v, x)
|
||||
return nil
|
||||
}); err != nil {
|
||||
purgeResource()
|
||||
logger.Fatalf("could not create key: %s", err)
|
||||
}
|
||||
config := quick.Config{}
|
||||
if testing.Short() {
|
||||
config.MaxCount = 10
|
||||
|
||||
if key.KeyMetadata.Arn != nil {
|
||||
testKMSARN = *key.KeyMetadata.Arn
|
||||
} else {
|
||||
purgeResource()
|
||||
logger.Fatalf("could not set arn")
|
||||
}
|
||||
if err := quick.Check(f, &config); err != nil {
|
||||
t.Error(err)
|
||||
|
||||
// Run the tests, but only if we succeeded in setting up the AWS KMS server.
|
||||
var code int
|
||||
if err == nil {
|
||||
code = m.Run()
|
||||
}
|
||||
|
||||
// This can't be deferred, as os.Exit simpy does not care
|
||||
if err := pool.Purge(resource); err != nil {
|
||||
logger.Fatalf("could not purge resource: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestKMSKeySourceFromString(t *testing.T) {
|
||||
func TestNewMasterKey(t *testing.T) {
|
||||
var (
|
||||
dummyRole = "a-role"
|
||||
dummyEncryptionContext = map[string]*string{
|
||||
"foo": aws.String("bar"),
|
||||
}
|
||||
)
|
||||
key := NewMasterKey(dummyARN, dummyRole, dummyEncryptionContext)
|
||||
assert.Equal(t, dummyARN, key.Arn)
|
||||
assert.Equal(t, dummyRole, key.Role)
|
||||
assert.Equal(t, dummyEncryptionContext, key.EncryptionContext)
|
||||
assert.NotNil(t, key.CreationDate)
|
||||
}
|
||||
|
||||
func TestNewMasterKeyFromArn(t *testing.T) {
|
||||
t.Run("arn", func(t *testing.T) {
|
||||
var (
|
||||
dummyEncryptionContext = map[string]*string{
|
||||
"foo": aws.String("bar"),
|
||||
}
|
||||
dummyProfile = "a-profile"
|
||||
)
|
||||
key := NewMasterKeyFromArn(dummyARN, dummyEncryptionContext, dummyProfile)
|
||||
assert.Equal(t, dummyARN, key.Arn)
|
||||
assert.Equal(t, dummyEncryptionContext, key.EncryptionContext)
|
||||
assert.Equal(t, dummyProfile, key.AwsProfile)
|
||||
assert.Empty(t, key.Role)
|
||||
assert.NotNil(t, key.CreationDate)
|
||||
})
|
||||
|
||||
t.Run("arn with spaces", func(t *testing.T) {
|
||||
key := NewMasterKeyFromArn(" arn:aws:kms:us-west-2 :107501996527:key/612d5f 0p-p1l3-45e6-aca6-a5b00569 3a48 ", nil, "")
|
||||
assert.Equal(t, "arn:aws:kms:us-west-2:107501996527:key/612d5f0p-p1l3-45e6-aca6-a5b005693a48", key.Arn)
|
||||
})
|
||||
|
||||
t.Run("arn with role", func(t *testing.T) {
|
||||
key := NewMasterKeyFromArn("arn:aws:kms:us-west-2:927034868273:key/fe86dd69-4132-404c-ab86-4269956b4500+arn:aws:iam::927034868273:role/sops-dev-xyz", nil, "")
|
||||
assert.Equal(t, "arn:aws:kms:us-west-2:927034868273:key/fe86dd69-4132-404c-ab86-4269956b4500", key.Arn)
|
||||
assert.Equal(t, "arn:aws:iam::927034868273:role/sops-dev-xyz", key.Role)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMasterKeysFromArnString(t *testing.T) {
|
||||
s := "arn:aws:kms:us-east-1:656532927350:key/920aff2e-c5f1-4040-943a-047fa387b27e+arn:aws:iam::927034868273:role/sops-dev, arn:aws:kms:ap-southeast-1:656532927350:key/9006a8aa-0fa6-4c14-930e-a2dfb916de1d"
|
||||
ks := MasterKeysFromArnString(s, nil, "foo")
|
||||
k1 := ks[0]
|
||||
k2 := ks[1]
|
||||
|
||||
expectedArn1 := "arn:aws:kms:us-east-1:656532927350:key/920aff2e-c5f1-4040-943a-047fa387b27e"
|
||||
expectedRole1 := "arn:aws:iam::927034868273:role/sops-dev"
|
||||
if k1.Arn != expectedArn1 {
|
||||
t.Errorf("ARN mismatch. Expected %s, found %s", expectedArn1, k1.Arn)
|
||||
}
|
||||
if k1.Role != expectedRole1 {
|
||||
t.Errorf("Role mismatch. Expected %s, found %s", expectedRole1, k1.Role)
|
||||
}
|
||||
assert.Equal(t, expectedArn1, k1.Arn)
|
||||
assert.Equal(t, expectedRole1, k1.Role)
|
||||
|
||||
expectedArn2 := "arn:aws:kms:ap-southeast-1:656532927350:key/9006a8aa-0fa6-4c14-930e-a2dfb916de1d"
|
||||
expectedRole2 := ""
|
||||
if k2.Arn != expectedArn2 {
|
||||
t.Errorf("ARN mismatch. Expected %s, found %s", expectedArn2, k2.Arn)
|
||||
}
|
||||
if k2.Role != expectedRole2 {
|
||||
t.Errorf("Role mismatch. Expected empty role, found %s.", k2.Role)
|
||||
}
|
||||
assert.Equal(t, expectedArn2, k2.Arn)
|
||||
assert.Empty(t, k2.Role)
|
||||
}
|
||||
|
||||
func TestParseEncryptionContext(t *testing.T) {
|
||||
func TestParseKMSContext(t *testing.T) {
|
||||
value1 := "value1"
|
||||
value2 := "value2"
|
||||
// map from YAML
|
||||
@@ -113,7 +210,151 @@ func TestParseEncryptionContext(t *testing.T) {
|
||||
assert.Nil(t, ParseKMSContext("key1"))
|
||||
}
|
||||
|
||||
func TestKeyToMap(t *testing.T) {
|
||||
func TestCreds_ApplyToMasterKey(t *testing.T) {
|
||||
creds := NewCredentialsProvider(credentials.NewStaticCredentialsProvider("", "", ""))
|
||||
key := &MasterKey{}
|
||||
creds.ApplyToMasterKey(key)
|
||||
assert.Equal(t, creds.provider, key.credentialsProvider)
|
||||
}
|
||||
|
||||
func TestMasterKey_Encrypt(t *testing.T) {
|
||||
t.Run("encrypt", func(t *testing.T) {
|
||||
key := createTestMasterKey(testKMSARN)
|
||||
dataKey := []byte("UFO sightings")
|
||||
assert.NoError(t, key.Encrypt(dataKey))
|
||||
assert.NotEmpty(t, key.EncryptedKey)
|
||||
|
||||
kmsClient, err := createTestKMSClient(key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
k, err := base64.StdEncoding.DecodeString(key.EncryptedKey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
input := &kms.DecryptInput{
|
||||
CiphertextBlob: k,
|
||||
EncryptionContext: stringPointerToStringMap(key.EncryptionContext),
|
||||
}
|
||||
decrypted, err := kmsClient.Decrypt(context.TODO(), input)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataKey, decrypted.Plaintext)
|
||||
})
|
||||
|
||||
t.Run("encrypt error", func(t *testing.T) {
|
||||
// Valid ARN but invalid for test server.
|
||||
key := createTestMasterKey(dummyARN)
|
||||
err := key.Encrypt([]byte("UFO sightings"))
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "failed to encrypt sops data key with AWS KMS")
|
||||
assert.Empty(t, key.EncryptedKey)
|
||||
})
|
||||
|
||||
t.Run("config error", func(t *testing.T) {
|
||||
key := createTestMasterKey("arn:gcp:kms:antartica-north-2::key/45e6-aca6-a5b005693a48")
|
||||
err := key.Encrypt([]byte(""))
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "no valid ARN found")
|
||||
assert.Empty(t, key.EncryptedKey)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMasterKey_EncryptIfNeeded(t *testing.T) {
|
||||
key := createTestMasterKey(testKMSARN)
|
||||
assert.NoError(t, key.EncryptIfNeeded([]byte("data")))
|
||||
|
||||
encryptedKey := key.EncryptedKey
|
||||
assert.NotEmpty(t, encryptedKey)
|
||||
|
||||
assert.NoError(t, key.EncryptIfNeeded([]byte("some other data")))
|
||||
assert.Equal(t, encryptedKey, key.EncryptedKey)
|
||||
}
|
||||
|
||||
func TestMasterKey_EncryptedDataKey(t *testing.T) {
|
||||
key := &MasterKey{EncryptedKey: "some key"}
|
||||
assert.EqualValues(t, key.EncryptedKey, key.EncryptedDataKey())
|
||||
}
|
||||
|
||||
func TestMasterKey_SetEncryptedDataKey(t *testing.T) {
|
||||
key := &MasterKey{}
|
||||
data := []byte("some data")
|
||||
key.SetEncryptedDataKey(data)
|
||||
assert.EqualValues(t, data, key.EncryptedKey)
|
||||
}
|
||||
|
||||
func TestMasterKey_Decrypt(t *testing.T) {
|
||||
t.Run("decrypt", func(t *testing.T) {
|
||||
key := createTestMasterKey(testKMSARN)
|
||||
kmsClient, err := createTestKMSClient(key)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dataKey := []byte("it's always DNS")
|
||||
out, err := kmsClient.Encrypt(context.TODO(), &kms.EncryptInput{
|
||||
Plaintext: dataKey, KeyId: &key.Arn, EncryptionContext: stringPointerToStringMap(key.EncryptionContext),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
key.EncryptedKey = base64.StdEncoding.EncodeToString(out.CiphertextBlob)
|
||||
got, err := key.Decrypt()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataKey, got)
|
||||
})
|
||||
|
||||
t.Run("data key error", func(t *testing.T) {
|
||||
key := createTestMasterKey(testKMSARN)
|
||||
key.EncryptedKey = "invalid"
|
||||
got, err := key.Decrypt()
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "error base64-decoding encrypted data key")
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
|
||||
t.Run("decrypt error", func(t *testing.T) {
|
||||
// Valid ARN but invalid for test server.
|
||||
key := createTestMasterKey(dummyARN)
|
||||
key.EncryptedKey = base64.StdEncoding.EncodeToString([]byte("invalid"))
|
||||
got, err := key.Decrypt()
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "failed to decrypt sops data key with AWS KMS")
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
|
||||
t.Run("config error", func(t *testing.T) {
|
||||
key := createTestMasterKey("arn:gcp:kms:antartica-north-2::key/45e6-aca6-a5b005693a48")
|
||||
got, err := key.Decrypt()
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "no valid ARN found")
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMasterKey_EncryptDecrypt_RoundTrip(t *testing.T) {
|
||||
dataKey := []byte("the wheels on the bus go round and round")
|
||||
|
||||
encryptKey := createTestMasterKey(testKMSARN)
|
||||
assert.NoError(t, encryptKey.Encrypt(dataKey))
|
||||
assert.NotEmpty(t, encryptKey.EncryptedKey)
|
||||
|
||||
decryptKey := createTestMasterKey(testKMSARN)
|
||||
decryptKey.EncryptedKey = encryptKey.EncryptedKey
|
||||
|
||||
decryptedData, err := decryptKey.Decrypt()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, dataKey, decryptedData)
|
||||
}
|
||||
|
||||
func TestMasterKey_NeedsRotation(t *testing.T) {
|
||||
key := NewMasterKeyFromArn(dummyARN, nil, "")
|
||||
assert.False(t, key.NeedsRotation())
|
||||
|
||||
key.CreationDate = key.CreationDate.Add(-(kmsTTL + time.Second))
|
||||
assert.True(t, key.NeedsRotation())
|
||||
}
|
||||
|
||||
func TestMasterKey_ToString(t *testing.T) {
|
||||
key := NewMasterKeyFromArn(dummyARN, nil, "")
|
||||
assert.Equal(t, dummyARN, key.ToString())
|
||||
}
|
||||
|
||||
func TestMasterKey_ToMap(t *testing.T) {
|
||||
value1 := "value1"
|
||||
value2 := "value2"
|
||||
key := MasterKey{
|
||||
@@ -137,3 +378,198 @@ func TestKeyToMap(t *testing.T) {
|
||||
},
|
||||
}, key.ToMap())
|
||||
}
|
||||
|
||||
func TestMasterKey_createKMSConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key MasterKey
|
||||
assertFunc func(t *testing.T, cfg *aws.Config, err error)
|
||||
fallback bool
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
key: MasterKey{
|
||||
credentialsProvider: credentials.NewStaticCredentialsProvider("test-id", "test-secret", "test-token"),
|
||||
AwsProfile: "test-profile",
|
||||
Arn: "arn:aws:kms:us-west-2:107501996527:key/612d5f0p-p1l3-45e6-aca6-a5b005693a48",
|
||||
},
|
||||
assertFunc: func(t *testing.T, cfg *aws.Config, err error) {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "us-west-2", cfg.Region)
|
||||
|
||||
creds, err := cfg.Credentials.Retrieve(context.TODO())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test-id", creds.AccessKeyID)
|
||||
assert.Equal(t, "test-secret", creds.SecretAccessKey)
|
||||
assert.Equal(t, "test-token", creds.SessionToken)
|
||||
|
||||
// ConfigSources is a slice of config.Config, which in turn is an interface.
|
||||
// Since we use a LoadOptions object, we assert the type of cfgSrc and then
|
||||
// check if the expected profile is present.
|
||||
for _, cfgSrc := range cfg.ConfigSources {
|
||||
if src, ok := cfgSrc.(config.LoadOptions); ok {
|
||||
assert.Equal(t, "test-profile", src.SharedConfigProfile)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid arn",
|
||||
key: MasterKey{
|
||||
Arn: "arn:gcp:kms:antartica-north-2::key/45e6-aca6-a5b005693a48",
|
||||
},
|
||||
assertFunc: func(t *testing.T, cfg *aws.Config, err error) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "no valid ARN found")
|
||||
assert.Nil(t, cfg)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "STS config attempt",
|
||||
key: MasterKey{
|
||||
Arn: dummyARN,
|
||||
Role: "role",
|
||||
},
|
||||
assertFunc: func(t *testing.T, cfg *aws.Config, err error) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "failed to assume role 'role'")
|
||||
assert.Nil(t, cfg)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "client default fallback",
|
||||
key: MasterKey{
|
||||
Arn: "arn:aws:kms:us-west-2:107501996527:key/612d5f0p-p1l3-45e6-aca6-a5b005693a48",
|
||||
},
|
||||
fallback: true,
|
||||
assertFunc: func(t *testing.T, cfg *aws.Config, err error) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
creds, err := cfg.Credentials.Retrieve(context.TODO())
|
||||
assert.Equal(t, "id", creds.AccessKeyID)
|
||||
assert.Equal(t, "secret", creds.SecretAccessKey)
|
||||
assert.Equal(t, "token", creds.SessionToken)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt := tt
|
||||
// Set the environment variables if we want to fallback
|
||||
if tt.fallback {
|
||||
t.Setenv("AWS_ACCESS_KEY_ID", "id")
|
||||
t.Setenv("AWS_SECRET_ACCESS_KEY", "secret")
|
||||
t.Setenv("AWS_SESSION_TOKEN", "token")
|
||||
}
|
||||
cfg, err := tt.key.createKMSConfig()
|
||||
tt.assertFunc(t, cfg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMasterKey_createSTSConfig(t *testing.T) {
|
||||
t.Run("session name error", func(t *testing.T) {
|
||||
defer func() { osHostname = os.Hostname }()
|
||||
osHostname = func() (name string, err error) {
|
||||
err = fmt.Errorf("an error")
|
||||
return
|
||||
}
|
||||
key := NewMasterKeyFromArn(dummyARN, nil, "")
|
||||
cfg, err := key.createSTSConfig(nil)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "failed to construct STS session name")
|
||||
assert.Nil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("role assumption error", func(t *testing.T) {
|
||||
key := NewMasterKeyFromArn(dummyARN, nil, "")
|
||||
key.Role = "role"
|
||||
got, err := key.createSTSConfig(&aws.Config{})
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "failed to assume role 'role'")
|
||||
assert.Nil(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_stsSessionName(t *testing.T) {
|
||||
t.Run("STS session name", func(t *testing.T) {
|
||||
defer func() { osHostname = os.Hostname }()
|
||||
const mockHostname = "hostname"
|
||||
osHostname = func() (name string, err error) {
|
||||
name = mockHostname
|
||||
return
|
||||
}
|
||||
got, err := stsSessionName()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sops@"+mockHostname, got)
|
||||
})
|
||||
|
||||
t.Run("hostname error", func(t *testing.T) {
|
||||
defer func() { osHostname = os.Hostname }()
|
||||
osHostname = func() (name string, err error) {
|
||||
err = fmt.Errorf("an error")
|
||||
return
|
||||
}
|
||||
got, err := stsSessionName()
|
||||
assert.Error(t, err)
|
||||
assert.ErrorContains(t, err, "failed to construct STS session nam")
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
|
||||
t.Run("replaces with stsSessionRegex", func(t *testing.T) {
|
||||
const mockHostname = "some-hostname"
|
||||
defer func() { osHostname = os.Hostname }()
|
||||
osHostname = func() (name string, err error) {
|
||||
name = mockHostname
|
||||
return
|
||||
}
|
||||
got, err := stsSessionName()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "sops@somehostname", got)
|
||||
})
|
||||
|
||||
t.Run("hostname exceeding roleSessionNameLengthLimit", func(t *testing.T) {
|
||||
const mockHostname = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
defer func() { osHostname = os.Hostname }()
|
||||
osHostname = func() (name string, err error) {
|
||||
name = mockHostname
|
||||
return
|
||||
}
|
||||
got, err := stsSessionName()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, "sops@"+mockHostname, got)
|
||||
assert.Len(t, got, roleSessionNameLengthLimit)
|
||||
})
|
||||
}
|
||||
|
||||
// createTestMasterKey creates a MasterKey with the provided ARN and a dummy
|
||||
// credentials.StaticCredentialsProvider.
|
||||
func createTestMasterKey(arn string) MasterKey {
|
||||
return MasterKey{
|
||||
Arn: arn,
|
||||
credentialsProvider: credentials.NewStaticCredentialsProvider("id", "secret", ""),
|
||||
epResolver: epResolver{},
|
||||
}
|
||||
}
|
||||
|
||||
// createTestKMSClient creates a new client with the
|
||||
// aws.EndpointResolverWithOptions set to epResolver.
|
||||
func createTestKMSClient(key MasterKey) (*kms.Client, error) {
|
||||
cfg, err := key.createKMSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.EndpointResolverWithOptions = epResolver{}
|
||||
return kms.NewFromConfig(*cfg), nil
|
||||
}
|
||||
|
||||
// epResolver is a dummy resolver that points to the local test KMS server.
|
||||
type epResolver struct{}
|
||||
|
||||
// ResolveEndpoint always resolves to testKMSServerURL.
|
||||
func (e epResolver) ResolveEndpoint(_, _ string, _ ...interface{}) (aws.Endpoint, error) {
|
||||
return aws.Endpoint{
|
||||
URL: testKMSServerURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
4058
kms/mocks/KMSAPI.go
4058
kms/mocks/KMSAPI.go
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user