diff --git a/config/config_test.go b/config/config_test.go index abf2c66a5..805b31699 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -94,6 +94,9 @@ creation_rules: - kms: - arn: foo aws_profile: bar + - arn: foo + context: + baz: bam pgp: - bar gcp_kms: @@ -421,6 +424,7 @@ func TestLoadConfigFile(t *testing.T) { } func TestLoadConfigFileWithGroups(t *testing.T) { + bam := "bam" expected := configFile{ CreationRules: []creationRule{ { @@ -432,7 +436,18 @@ func TestLoadConfigFileWithGroups(t *testing.T) { PathRegex: "", KeyGroups: []keyGroup{ { - KMS: []kmsKey{{Arn: "foo", AwsProfile: "bar"}}, + KMS: []kmsKey{ + { + Arn: "foo", + AwsProfile: "bar", + }, + { + Arn: "foo", + Context: map[string]*string{ + "baz": &bam, + }, + }, + }, PGP: []string{"bar"}, GCPKMS: []gcpKmsKey{{ResourceID: "foo"}}, AzureKV: []azureKVKey{{VaultURL: "https://foo.vault.azure.net", Key: "foo-key", Version: "fooversion"}}, @@ -464,7 +479,7 @@ func TestLoadConfigFileWithMerge(t *testing.T) { assert.Nil(t, err) assert.Equal(t, 2, len(conf.KeyGroups)) assert.Equal(t, 1, len(conf.KeyGroups[0])) - assert.Equal(t, 22, len(conf.KeyGroups[1])) + assert.Equal(t, 23, len(conf.KeyGroups[1])) } func TestLoadConfigFileWithNoMatchingRules(t *testing.T) { @@ -538,9 +553,10 @@ func TestKeyGroupsForFileWithGroups(t *testing.T) { conf, err := parseCreationRuleForFile(parseConfigFile(sampleConfigWithGroups, t), "/conf/path", "whatever", nil) assert.Nil(t, err) assert.Equal(t, "bar", conf.KeyGroups[0][0].ToString()) - assert.Equal(t, "foo", conf.KeyGroups[0][1].ToString()) + assert.Equal(t, "foo||bar", conf.KeyGroups[0][1].ToString()) + assert.Equal(t, "foo|baz:bam", conf.KeyGroups[0][2].ToString()) assert.Equal(t, "qux", conf.KeyGroups[1][0].ToString()) - assert.Equal(t, "baz", conf.KeyGroups[1][1].ToString()) + assert.Equal(t, "baz||foo", conf.KeyGroups[1][1].ToString()) } func TestLoadConfigFileWithUnencryptedSuffix(t *testing.T) { diff --git a/kms/keysource.go b/kms/keysource.go index e1441b492..bf222c8b0 100644 --- a/kms/keysource.go +++ b/kms/keysource.go @@ -11,6 +11,7 @@ import ( "fmt" "os" "regexp" + "sort" "strings" "time" @@ -181,6 +182,38 @@ func ParseKMSContext(in interface{}) map[string]*string { return out } +// kmsContextToString converts a dictionary into a string that can be parsed +// again with ParseKMSContext(). +func kmsContextToString(in map[string]*string) string { + if len(in) == 0 { + return "" + } + + // Collect the keys in a slice and compute the expected length + keys := make([]string, 0, len(in)) + length := 0 + for key := range in { + keys = append(keys, key) + length += len(key) + len(*in[key]) + 2 + } + + // Sort the keys + sort.Strings(keys) + + // Compose a comma-separated string of key-vale pairs + var builder strings.Builder + builder.Grow(length) + for index, key := range keys { + if index > 0 { + builder.WriteString(",") + } + builder.WriteString(key) + builder.WriteByte(':') + builder.WriteString(*in[key]) + } + return builder.String() +} + // CredentialsProvider is a wrapper around aws.CredentialsProvider used for // authentication towards AWS KMS. type CredentialsProvider struct { @@ -278,7 +311,18 @@ func (key *MasterKey) NeedsRotation() bool { // ToString converts the key to a string representation. func (key *MasterKey) ToString() string { - return key.Arn + arnRole := key.Arn + if key.Role != "" { + arnRole = fmt.Sprintf("%s+%s", key.Arn, key.Role) + } + context := kmsContextToString(key.EncryptionContext) + if key.AwsProfile != "" { + return fmt.Sprintf("%s|%s|%s", arnRole, context, key.AwsProfile) + } + if len(key.EncryptionContext) > 0 { + return fmt.Sprintf("%s|%s", arnRole, context) + } + return arnRole } // ToMap converts the MasterKey to a map for serialization purposes. diff --git a/kms/keysource_test.go b/kms/keysource_test.go index a44ab1ec2..da3c6b51e 100644 --- a/kms/keysource_test.go +++ b/kms/keysource_test.go @@ -367,8 +367,38 @@ func TestMasterKey_NeedsRotation(t *testing.T) { } func TestMasterKey_ToString(t *testing.T) { + dummyARNWithRole := fmt.Sprintf("%s+arn:aws:iam::my-role", dummyARN) + + bar := "bar" + bam := "bam" + context := map[string]*string{ + "foo": &bar, + "baz": &bam, + } + key := NewMasterKeyFromArn(dummyARN, nil, "") assert.Equal(t, dummyARN, key.ToString()) + + key = NewMasterKeyFromArn(dummyARNWithRole, nil, "") + assert.Equal(t, dummyARNWithRole, key.ToString()) + + key = NewMasterKeyFromArn(dummyARN, nil, "profile") + assert.Equal(t, fmt.Sprintf("%s||profile", dummyARN), key.ToString()) + + key = NewMasterKeyFromArn(dummyARNWithRole, nil, "profile") + assert.Equal(t, fmt.Sprintf("%s||profile", dummyARNWithRole), key.ToString()) + + key = NewMasterKeyFromArn(dummyARN, context, "") + assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar", dummyARN), key.ToString()) + + key = NewMasterKeyFromArn(dummyARNWithRole, context, "") + assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar", dummyARNWithRole), key.ToString()) + + key = NewMasterKeyFromArn(dummyARN, context, "profile") + assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar|profile", dummyARN), key.ToString()) + + key = NewMasterKeyFromArn(dummyARNWithRole, context, "profile") + assert.Equal(t, fmt.Sprintf("%s|baz:bam,foo:bar|profile", dummyARNWithRole), key.ToString()) } func TestMasterKey_ToMap(t *testing.T) {