From ce5694a128d81f2c197af41062d79cdb49ccc691 Mon Sep 17 00:00:00 2001 From: Lucas Earl Date: Mon, 21 Jul 2025 16:05:32 -0600 Subject: [PATCH] Addressing felixfontein's latest review. Adds a key type field to the ParseKeyField fn. Signed-off-by: Lucas Earl --- config/config.go | 30 +++++++++++++++++++----------- config/config_test.go | 8 ++++---- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/config/config.go b/config/config.go index 520f291e5..fa2e24062 100644 --- a/config/config.go +++ b/config/config.go @@ -191,31 +191,35 @@ type creationRule struct { // Helper methods to safely extract keys as []string func (c *creationRule) GetKMSKeys() ([]string, error) { - return parseKeyField(c.KMS) + return parseKeyField(c.KMS, "kms") } func (c *creationRule) GetAgeKeys() ([]string, error) { - return parseKeyField(c.Age) + return parseKeyField(c.Age, "age") } func (c *creationRule) GetPGPKeys() ([]string, error) { - return parseKeyField(c.PGP) + return parseKeyField(c.PGP, "pgp") } func (c *creationRule) GetGCPKMSKeys() ([]string, error) { - return parseKeyField(c.GCPKMS) + return parseKeyField(c.GCPKMS, "gcp_kms") } func (c *creationRule) GetAzureKeyVaultKeys() ([]string, error) { - return parseKeyField(c.AzureKeyVault) + return parseKeyField(c.AzureKeyVault, "azure_keyvault") } func (c *creationRule) GetVaultURIs() ([]string, error) { - return parseKeyField(c.VaultURI) + return parseKeyField(c.VaultURI, "hc_vault_transit_uri") } // Utility function to handle both string and []string -func parseKeyField(field interface{}) ([]string, error) { +func parseKeyField(field interface{}, fieldName string) ([]string, error) { + if field == nil { + return []string{}, nil + } + switch v := field.(type) { case string: if v == "" { @@ -234,13 +238,17 @@ func parseKeyField(field interface{}) ([]string, error) { case []interface{}: result := make([]string, len(v)) for i, item := range v { - result[i] = fmt.Sprintf("%v", item) + if str, ok := item.(string); ok { + result[i] = str + } else { + return nil, fmt.Errorf("invalid %s key configuration: expected string in list, got %T", fieldName, item) + } } return result, nil case []string: return v, nil default: - return nil, fmt.Errorf("invalid key field type: expected string, []string, or nil, got %T", field) + return nil, fmt.Errorf("invalid %s key configuration: expected string, []string, or nil, got %T", fieldName, field) } } @@ -359,7 +367,7 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[ return nil, err } - if cRule.Age != "" { + if len(ageKeys) > 0 { ageKeys, err := age.MasterKeysFromRecipients(strings.Join(ageKeys, ",")) if err != nil { return nil, err @@ -390,7 +398,7 @@ func getKeyGroupsFromCreationRule(cRule *creationRule, kmsEncryptionContext map[ for _, k := range gcpkms.MasterKeysFromResourceIDString(strings.Join(gcpkmsKeys, ",")) { keyGroup = append(keyGroup, k) } - azKeys, err := getKeysWithValidation(cRule.GetAzureKeyVaultKeys, "axkeyvault") + azKeys, err := getKeysWithValidation(cRule.GetAzureKeyVaultKeys, "azure_keyvault") if err != nil { return nil, err } diff --git a/config/config_test.go b/config/config_test.go index af08d761e..7c869fedc 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -577,14 +577,14 @@ func TestLoadConfigFileWithInvalidComplicatedRegexp(t *testing.T) { } func TestLoadConfigFileWithComplicatedRegexp(t *testing.T) { - for filePath, _ := range map[string]string{ + for filePath, k := range map[string]string{ "stage/prod/api.yml": "default", "stage/dev/feature-foo.yml": "dev-feature", "stage/dev/api.yml": "dev", } { conf, err := parseCreationRuleForFile(parseConfigFile(sampleConfigWithComplicatedRegexp, t), "/conf/path", filePath, nil) - assert.Nil(t, conf) - assert.ErrorContains(t, err, "invalid age key configuration: invalid key field type: expected string, []string, or nil, got") + assert.Nil(t, err) + assert.Equal(t, k, conf.KeyGroups[0][0].ToString()) } } @@ -741,7 +741,7 @@ creation_rules: t.Fatal("Expected configuration but got nil") } - assert.True(t, len(conf.KeyGroups) > 0) + assert.True(t, len(conf.KeyGroups) == 1) assert.True(t, len(conf.KeyGroups[0]) == 6) keyTypeCounts := make(map[string]int)