diff --git a/config/config.go b/config/config.go index 06ad57370..e2c303b63 100644 --- a/config/config.go +++ b/config/config.go @@ -170,23 +170,78 @@ type destinationRule struct { } type creationRule struct { - PathRegex string `yaml:"path_regex"` - KMS string - AwsProfile string `yaml:"aws_profile"` - Age string `yaml:"age"` - PGP string - GCPKMS string `yaml:"gcp_kms"` - AzureKeyVault string `yaml:"azure_keyvault"` - VaultURI string `yaml:"hc_vault_transit_uri"` - KeyGroups []keyGroup `yaml:"key_groups"` - ShamirThreshold int `yaml:"shamir_threshold"` - UnencryptedSuffix string `yaml:"unencrypted_suffix"` - EncryptedSuffix string `yaml:"encrypted_suffix"` - UnencryptedRegex string `yaml:"unencrypted_regex"` - EncryptedRegex string `yaml:"encrypted_regex"` - UnencryptedCommentRegex string `yaml:"unencrypted_comment_regex"` - EncryptedCommentRegex string `yaml:"encrypted_comment_regex"` - MACOnlyEncrypted bool `yaml:"mac_only_encrypted"` + PathRegex string `yaml:"path_regex"` + KMS interface{} `yaml:"kms"` // string or []string + AwsProfile string `yaml:"aws_profile"` + Age interface{} `yaml:"age"` // string or []string + PGP interface{} `yaml:"pgp"` // string or []string + GCPKMS interface{} `yaml:"gcp_kms"` // string or []string + AzureKeyVault interface{} `yaml:"azure_keyvault"` // string or []string + VaultURI interface{} `yaml:"hc_vault_transit_uri"` // string or []string + KeyGroups []keyGroup `yaml:"key_groups"` + ShamirThreshold int `yaml:"shamir_threshold"` + UnencryptedSuffix string `yaml:"unencrypted_suffix"` + EncryptedSuffix string `yaml:"encrypted_suffix"` + UnencryptedRegex string `yaml:"unencrypted_regex"` + EncryptedRegex string `yaml:"encrypted_regex"` + UnencryptedCommentRegex string `yaml:"unencrypted_comment_regex"` + EncryptedCommentRegex string `yaml:"encrypted_comment_regex"` + MACOnlyEncrypted bool `yaml:"mac_only_encrypted"` +} + +// Helper methods to safely extract keys as []string +func (c *creationRule) GetKMSKeys() []string { + return parseKeyField(c.KMS) +} + +func (c *creationRule) GetAgeKeys() []string { + return parseKeyField(c.Age) +} + +func (c *creationRule) GetPGPKeys() []string { + return parseKeyField(c.PGP) +} + +func (c *creationRule) GetGCPKMSKeys() []string { + return parseKeyField(c.GCPKMS) +} + +func (c *creationRule) GetAzureKeyVaultKeys() []string { + return parseKeyField(c.AzureKeyVault) +} + +func (c *creationRule) GetVaultURIs() []string { + return parseKeyField(c.VaultURI) +} + +// Utility function to handle both string and []string +func parseKeyField(field interface{}) []string { + switch v := field.(type) { + case string: + if v == "" { + return []string{} + } + // Existing CSV parsing logic + keys := strings.Split(v, ",") + result := make([]string, 0, len(keys)) + for _, key := range keys { + trimmed := strings.TrimSpace(key) + if trimmed != "" { // Skip empty strings (fixes trailing comma issue) + result = append(result, trimmed) + } + } + return result + case []interface{}: + result := make([]string, len(v)) + for i, item := range v { + result[i] = fmt.Sprintf("%v", item) + } + return result + case []string: + return v + default: + return []string{} + } } func NewStoresConfig() *StoresConfig { @@ -292,7 +347,7 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[ } else { var keyGroup sops.KeyGroup if cRule.Age != "" { - ageKeys, err := age.MasterKeysFromRecipients(cRule.Age) + ageKeys, err := age.MasterKeysFromRecipients(strings.Join(cRule.GetAgeKeys(), ",")) if err != nil { return nil, err } else { @@ -301,23 +356,23 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[ } } } - for _, k := range pgp.MasterKeysFromFingerprintString(cRule.PGP) { + for _, k := range pgp.MasterKeysFromFingerprintString(strings.Join(cRule.GetPGPKeys(), ",")) { keyGroup = append(keyGroup, k) } - for _, k := range kms.MasterKeysFromArnString(cRule.KMS, kmsEncryptionContext, cRule.AwsProfile) { + for _, k := range kms.MasterKeysFromArnString(strings.Join(cRule.GetKMSKeys(), ","), kmsEncryptionContext, cRule.AwsProfile) { keyGroup = append(keyGroup, k) } - for _, k := range gcpkms.MasterKeysFromResourceIDString(cRule.GCPKMS) { + for _, k := range gcpkms.MasterKeysFromResourceIDString(strings.Join(cRule.GetGCPKMSKeys(), ",")) { keyGroup = append(keyGroup, k) } - azureKeys, err := azkv.MasterKeysFromURLs(cRule.AzureKeyVault) + azureKeys, err := azkv.MasterKeysFromURLs(strings.Join(cRule.GetAzureKeyVaultKeys(), ",")) if err != nil { return nil, err } for _, k := range azureKeys { keyGroup = append(keyGroup, k) } - vaultKeys, err := hcvault.NewMasterKeysFromURIs(cRule.VaultURI) + vaultKeys, err := hcvault.NewMasterKeysFromURIs(strings.Join(cRule.GetVaultURIs(), ",")) if err != nil { return nil, err } diff --git a/config/config_test.go b/config/config_test.go index cb8340a7f..5b550059b 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -718,3 +718,40 @@ func TestLoadConfigFileWithVaultDestinationRules(t *testing.T) { assert.NotNil(t, conf.Destination) assert.Contains(t, conf.Destination.Path("barfoo"), "/v1/kv/barfoo/barfoo") } + +func TestCreationRuleNativeKeyLists(t *testing.T) { + var sampleConfigWithNativeKeyLists = []byte(` +creation_rules: + - path_regex: native_list* + pgp: + - "85D77543B3D624B63CEA9E6DBC17301B491B3F21" # name@email.com + - "FBC7B9E2A4F9289AC0C1D4843D16CEE4A27381B4" # server_XYZ + kms: + - "arn:aws:kms:us-east-1:123456789012:key/12345678-1234-1234-1234-123456789012" + age: + - "age1ql3z7hjy54pw3hyww5ayyfg7zqgvc7w3j2elw8zmrj2kg5sfn9aqmcac8p" + gcp_kms: + - "projects/test-project/locations/global/keyRings/test-ring/cryptoKeys/test-key" + hc_vault_transit_uri: + - "https://vault.example.com:8200/v1/transit/keys/key1" +`) + conf, err := parseCreationRuleForFile(parseConfigFile(sampleConfigWithNativeKeyLists, t), "/conf/path", "native_list_test", nil) + assert.Nil(t, err) + if conf == nil { + t.Fatal("Expected configuration but got nil") + } + + assert.True(t, len(conf.KeyGroups) > 0) + assert.True(t, len(conf.KeyGroups[0]) == 6) + + keyTypeCounts := make(map[string]int) + for _, key := range conf.KeyGroups[0] { + keyTypeCounts[key.TypeToIdentifier()]++ + } + + assert.Equal(t, 2, keyTypeCounts["pgp"]) + assert.Equal(t, 1, keyTypeCounts["kms"]) + assert.Equal(t, 1, keyTypeCounts["age"]) + assert.Equal(t, 1, keyTypeCounts["gcp_kms"]) + assert.Equal(t, 1, keyTypeCounts["hc_vault"]) +}