diff --git a/cmd/sops/common/common.go b/cmd/sops/common/common.go index 7a52b8fcd..e684c2f0e 100644 --- a/cmd/sops/common/common.go +++ b/cmd/sops/common/common.go @@ -32,7 +32,7 @@ type DecryptTreeOpts struct { func DecryptTree(opts DecryptTreeOpts) (dataKey []byte, err error) { dataKey, err = opts.Tree.Metadata.GetDataKeyWithKeyServices(opts.KeyServices) if err != nil { - return nil, cli.NewExitError(err.Error(), codes.CouldNotRetrieveKey) + return nil, NewExitError(err, codes.CouldNotRetrieveKey) } computedMac, err := opts.Tree.Decrypt(dataKey, opts.Cipher) if err != nil { @@ -96,6 +96,13 @@ func LoadEncryptedFile(inputStore sops.Store, inputPath string) (*sops.Tree, err return &tree, nil } +func NewExitError(i interface{}, exitCode int) *cli.ExitError { + if userErr, ok := i.(sops.UserError); ok { + return cli.NewExitError(userErr.UserError(), exitCode) + } + return cli.NewExitError(i, exitCode) +} + func IsYAMLFile(path string) bool { return strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml") } diff --git a/sops.go b/sops.go index ca824b774..95b4ae35c 100644 --- a/sops.go +++ b/sops.go @@ -435,35 +435,22 @@ func (m Metadata) GetDataKeyWithKeyServices(svcs []keyservice.KeyServiceClient) if m.DataKey != nil { return m.DataKey, nil } - errMsg := "Could not decrypt the data key with any of the master keys:\n" + getDataKeyErr := getDataKeyError{ + RequiredSuccessfulKeyGroups: m.ShamirThreshold, + GroupResults: make([]error, len(m.KeyGroups)), + } var parts [][]byte - for _, group := range m.KeyGroups { - keysLoop: - for _, key := range group { - svcKey := keyservice.KeyFromMasterKey(key) - for _, svc := range svcs { - rsp, err := svc.Decrypt( - context.Background(), - &keyservice.DecryptRequest{ - Ciphertext: key.EncryptedDataKey(), - Key: &svcKey, - }) - if err != nil { - errMsg += fmt.Sprintf("\t%s: %s", key.ToString(), err) - continue - } - parts = append(parts, rsp.Plaintext) - // All keys in a key group encrypt the same part, so as soon - // as we decrypt it successfully with one key, we need to - // proceed with the next group - break keysLoop - } + for i, group := range m.KeyGroups { + part, err := decryptKeyGroup(group, svcs) + if err == nil { + parts = append(parts, part) } + getDataKeyErr.GroupResults[i] = err } var dataKey []byte if len(m.KeyGroups) > 1 { if len(parts) < m.ShamirThreshold { - return nil, fmt.Errorf("not enough parts to recover data key with Shamir: need %d, have %d", m.ShamirThreshold, len(parts)) + return nil, &getDataKeyErr } var err error dataKey, err = shamir.Combine(parts) @@ -472,7 +459,7 @@ func (m Metadata) GetDataKeyWithKeyServices(svcs []keyservice.KeyServiceClient) } } else { if len(parts) != 1 { - return nil, fmt.Errorf("%s", errMsg) + return nil, &getDataKeyErr } dataKey = parts[0] } @@ -481,6 +468,55 @@ func (m Metadata) GetDataKeyWithKeyServices(svcs []keyservice.KeyServiceClient) return dataKey, nil } +// decryptKeyGroup tries to decrypt the contents of the provided KeyGroup with +// any of the MasterKeys in the KeyGroup with any of the provided key services, +// returning as soon as one key service succeeds. +func decryptKeyGroup(group KeyGroup, svcs []keyservice.KeyServiceClient) ([]byte, error) { + var keyErrs []error + for _, key := range group { + part, err := decryptKey(key, svcs) + if err != nil { + keyErrs = append(keyErrs, err) + } else { + return part, nil + } + } + return nil, decryptKeyErrors(keyErrs) +} + +// decryptKey tries to decrypt the contents of the provided MasterKey with any +// of the key services, returning as soon as one key service succeeds. +func decryptKey(key keys.MasterKey, svcs []keyservice.KeyServiceClient) ([]byte, error) { + svcKey := keyservice.KeyFromMasterKey(key) + var part []byte = nil + decryptErr := decryptKeyError{ + keyName: key.ToString(), + } + for _, svc := range svcs { + // All keys in a key group encrypt the same part, so as soon + // as we decrypt it successfully with one key, we need to + // proceed with the next group + var err error + if part == nil { + var rsp *keyservice.DecryptResponse + rsp, err = svc.Decrypt( + context.Background(), + &keyservice.DecryptRequest{ + Ciphertext: key.EncryptedDataKey(), + Key: &svcKey, + }) + if err == nil { + part = rsp.Plaintext + } + } + decryptErr.errs = append(decryptErr.errs, err) + } + if part != nil { + return part, nil + } + return nil, &decryptErr +} + // GetDataKey retrieves the data key from the first MasterKey in the Metadata's KeySources that's able to return it, // using the local KeyService func (m Metadata) GetDataKey() ([]byte, error) { diff --git a/usererrors.go b/usererrors.go new file mode 100644 index 000000000..f0f10813f --- /dev/null +++ b/usererrors.go @@ -0,0 +1,162 @@ +package sops + +import ( + "fmt" + "io/ioutil" + "strings" + + "github.com/fatih/color" + "github.com/goware/prefixer" + wordwrap "github.com/mitchellh/go-wordwrap" +) + +// UserError is a well-formatted error for the purpose of being displayed to +// the end user. +type UserError interface { + error + UserError() string +} + +var statusSuccess = color.New(color.FgGreen).Sprint("SUCCESS") +var statusFailed = color.New(color.FgRed).Sprint("FAILED") + +type getDataKeyError struct { + RequiredSuccessfulKeyGroups int + GroupResults []error +} + +func (err *getDataKeyError) successfulKeyGroups() int { + n := 0 + for _, r := range err.GroupResults { + if r == nil { + n++ + } + } + return n +} + +func (err *getDataKeyError) Error() string { + return fmt.Sprintf("Error getting data key: %d successful groups "+ + "required, got %d", err.RequiredSuccessfulKeyGroups, + err.successfulKeyGroups()) +} + +func (err *getDataKeyError) UserError() string { + var groupErrs []string + for i, res := range err.GroupResults { + groupErr := decryptGroupError{ + err: res, + groupName: fmt.Sprintf("%d", i), + } + groupErrs = append(groupErrs, groupErr.UserError()) + } + var trailer string + if err.RequiredSuccessfulKeyGroups == 0 { + trailer = "Recovery failed because no master key was able to decrypt " + + "the file. In order for SOPS to recover the file, at least one key " + + "has to be successful, but none were." + } else { + trailer = fmt.Sprintf("Recovery failed because the file was "+ + "encrypted with a Shamir threshold of %d, but only %d part(s) "+ + "were successfully recovered, one for each successful key group. "+ + "In order for SOPS to recover the file, at least %d groups have "+ + "to be successful. In order for a group to be successful, "+ + "decryption has to succeed with any of the keys in that key group.", + err.RequiredSuccessfulKeyGroups, err.successfulKeyGroups(), + err.RequiredSuccessfulKeyGroups) + } + trailer = wordwrap.WrapString(trailer, 75) + return fmt.Sprintf("Failed to get the data key required to "+ + "decrypt the SOPS file.\n\n%s\n\n%s", + strings.Join(groupErrs, "\n\n"), trailer) +} + +type decryptGroupError struct { + groupName string + err error +} + +func (r *decryptGroupError) Error() string { + return fmt.Sprintf("could not decryt group %s: %s", r.groupName, r.err) +} + +func (r *decryptGroupError) UserError() string { + var status string + if r.err == nil { + status = statusSuccess + } else { + status = statusFailed + } + header := fmt.Sprintf(`Group %s: %s`, r.groupName, status) + if r.err == nil { + return header + } + message := r.err.Error() + if userError, ok := r.err.(UserError); ok { + message = userError.UserError() + } + reader := prefixer.New(strings.NewReader(message), " ") + errMsg, _ := ioutil.ReadAll(reader) + return fmt.Sprintf("%s\n%s", header, string(errMsg)) +} + +type decryptKeyErrors []error + +func (e decryptKeyErrors) Error() string { + return fmt.Sprintf("error decrypting key: %s", []error(e)) +} + +func (e decryptKeyErrors) UserError() string { + var errStrs []string + for _, err := range []error(e) { + if userErr, ok := err.(UserError); ok { + errStrs = append(errStrs, userErr.UserError()) + } else { + errStrs = append(errStrs, err.Error()) + } + } + return strings.Join(errStrs, "\n\n") +} + +type decryptKeyError struct { + keyName string + errs []error +} + +func (e *decryptKeyError) isSuccessful() bool { + for _, err := range e.errs { + if err == nil { + return true + } + } + return false +} + +func (e *decryptKeyError) Error() string { + return fmt.Sprintf("error decrypting key %s: %s", e.keyName, e.errs) +} + +func (e *decryptKeyError) UserError() string { + var status string + if e.isSuccessful() { + status = statusSuccess + } else { + status = statusFailed + } + header := fmt.Sprintf("%s: %s", e.keyName, status) + if e.isSuccessful() { + return header + } + var errMessages []string + for _, err := range e.errs { + wrappedErr := wordwrap.WrapString(err.Error(), 60) + reader := prefixer.New(strings.NewReader(wrappedErr), " | ") + errMsg, _ := ioutil.ReadAll(reader) + errMsg[0] = '-' + errMessages = append(errMessages, string(errMsg)) + } + joinedMsgs := strings.Join(errMessages, "\n\n") + reader := prefixer.New(strings.NewReader(joinedMsgs), " ") + errMsg, _ := ioutil.ReadAll(reader) + return fmt.Sprintf("%s\n%s", header, string(errMsg)) +}