diff --git a/cmd/sops/main.go b/cmd/sops/main.go index 92efde8d6..939b104ea 100644 --- a/cmd/sops/main.go +++ b/cmd/sops/main.go @@ -460,15 +460,12 @@ func keyGroups(c *cli.Context, file string) ([]sops.KeyGroup, error) { return nil, cli.NewExitError(fmt.Sprintf("Error loading config file: %s", err), exitErrorReadingConfig) } } - kmsString, pgpString, err := yaml.MasterKeyStringsForFile(file, confBytes) - if err == nil { - for _, k := range pgp.MasterKeysFromFingerprintString(pgpString) { - pgpKeys = append(pgpKeys, k) - } - for _, k := range kms.MasterKeysFromArnString(kmsString, kmsEncryptionContext) { - kmsKeys = append(kmsKeys, k) - } + groups, err := yaml.KeyGroupsForFile(file, confBytes, kmsEncryptionContext) + if err != nil { + return nil, err } + log.Printf("Proceeding with key groups: %#v", groups) + return groups, err } return []sops.KeyGroup{append(kmsKeys, pgpKeys...)}, nil } diff --git a/yaml/config.go b/yaml/config.go index 8701a8dd9..b18678e3d 100644 --- a/yaml/config.go +++ b/yaml/config.go @@ -8,6 +8,9 @@ import ( "regexp" "github.com/mozilla-services/yaml" + "go.mozilla.org/sops" + "go.mozilla.org/sops/kms" + "go.mozilla.org/sops/pgp" ) type fileSystem interface { @@ -47,10 +50,16 @@ type configFile struct { CreationRules []creationRule `yaml:"creation_rules"` } +type keyGroup struct { + KMS string + PGP string +} + type creationRule struct { FilenameRegex string `yaml:"filename_regex"` KMS string PGP string + KeyGroups []keyGroup `yaml:"key_groups"` } // Load loads a sops config file into a temporary struct @@ -62,27 +71,50 @@ func (f *configFile) load(bytes []byte) error { return nil } -// MasterKeyStringsForFile returns a comma separated string of KMS ARNs and a comma separated list of PGP fingerprints. If the config bytes are left empty, the function will look for the config file by itself. -func MasterKeyStringsForFile(filepath string, confBytes []byte) (kms, pgp string, err error) { +func KeyGroupsForFile(filepath string, confBytes []byte, kmsEncryptionContext map[string]*string) ([]sops.KeyGroup, error) { + var err error if confBytes == nil { - confPath, err := FindConfigFile(".") + var confPath string + confPath, err = FindConfigFile(".") if err != nil { - return "", "", err + return nil, err } confBytes, err = ioutil.ReadFile(confPath) } if err != nil { - return "", "", fmt.Errorf("Could not read config file: %s", err) + return nil, fmt.Errorf("Could not read config file: %s", err) } conf := configFile{} err = conf.load(confBytes) if err != nil { - return "", "", fmt.Errorf("Error loading config: %s", err) + return nil, fmt.Errorf("Error loading config: %s", err) } + var groups []sops.KeyGroup for _, rule := range conf.CreationRules { if match, _ := regexp.MatchString(rule.FilenameRegex, filepath); match { - return rule.KMS, rule.PGP, nil + if len(rule.KeyGroups) > 0 { + for _, group := range rule.KeyGroups { + var keyGroup sops.KeyGroup + for _, k := range pgp.MasterKeysFromFingerprintString(group.PGP) { + keyGroup = append(keyGroup, k) + } + for _, k := range kms.MasterKeysFromArnString(group.KMS, kmsEncryptionContext) { + keyGroup = append(keyGroup, k) + } + groups = append(groups, keyGroup) + } + } else { + var keyGroup sops.KeyGroup + for _, k := range pgp.MasterKeysFromFingerprintString(rule.PGP) { + keyGroup = append(keyGroup, k) + } + for _, k := range kms.MasterKeysFromArnString(rule.KMS, kmsEncryptionContext) { + keyGroup = append(keyGroup, k) + } + groups = append(groups, keyGroup) + } + return groups, nil } } - return "", "", nil + return nil, nil } diff --git a/yaml/config_test.go b/yaml/config_test.go index afd3dd9b1..dc4ad257e 100644 --- a/yaml/config_test.go +++ b/yaml/config_test.go @@ -1,10 +1,11 @@ package yaml import ( - "github.com/stretchr/testify/assert" "os" "path" "testing" + + "github.com/stretchr/testify/assert" ) type mockFS struct { @@ -51,6 +52,19 @@ creation_rules: pgp: bar `) +var sampleConfigWithGroups = []byte(` +creation_rules: + - filename_regex: foobar* + kms: "1" + pgp: "2" + - filename_regex: "" + key_groups: + - kms: foo + pgp: bar + - kms: baz + pgp: qux +`) + func TestLoadConfigFile(t *testing.T) { expected := configFile{ CreationRules: []creationRule{ @@ -73,13 +87,52 @@ func TestLoadConfigFile(t *testing.T) { assert.Equal(t, expected, conf) } -func TestMasterKeyStringsForFile(t *testing.T) { - kms, pgp, err := MasterKeyStringsForFile("foobar2000", sampleConfig) +func TestLoadConfigFileWithGroups(t *testing.T) { + expected := configFile{ + CreationRules: []creationRule{ + { + FilenameRegex: "foobar*", + KMS: "1", + PGP: "2", + }, + { + FilenameRegex: "", + KeyGroups: []keyGroup{ + { + KMS: "foo", + PGP: "bar", + }, + { + KMS: "baz", + PGP: "qux", + }, + }, + }, + }, + } + + conf := configFile{} + err := conf.load(sampleConfigWithGroups) assert.Equal(t, nil, err) - assert.Equal(t, "1", kms) - assert.Equal(t, "2", pgp) - kms, pgp, err = MasterKeyStringsForFile("whatever", sampleConfig) - assert.Equal(t, nil, err) - assert.Equal(t, "foo", kms) - assert.Equal(t, "bar", pgp) + assert.Equal(t, expected, conf) +} + +func TestKeyGroupsForFile(t *testing.T) { + groups, err := KeyGroupsForFile("foobar2000", sampleConfig, nil) + assert.Equal(t, nil, err) + assert.Equal(t, "2", groups[0][0].ToString()) + assert.Equal(t, "1", groups[0][1].ToString()) + groups, err = KeyGroupsForFile("whatever", sampleConfig, nil) + assert.Equal(t, nil, err) + assert.Equal(t, "bar", groups[0][0].ToString()) + assert.Equal(t, "foo", groups[0][1].ToString()) +} + +func TestKeyGroupsForFileWithGroups(t *testing.T) { + groups, err := KeyGroupsForFile("whatever", sampleConfigWithGroups, nil) + assert.Equal(t, nil, err) + assert.Equal(t, "bar", groups[0][0].ToString()) + assert.Equal(t, "foo", groups[0][1].ToString()) + assert.Equal(t, "qux", groups[1][0].ToString()) + assert.Equal(t, "baz", groups[1][1].ToString()) }