1
0
mirror of https://github.com/lxc/distrobuilder.git synced 2026-02-05 06:45:19 +01:00
Files
distrobuilder/sources/common.go
Chaosoffire 64b60db96c sources: enforce GPG verification across multiple distros
This commit centralizes the GPG verification requirement logic in
`sources/common.go`, in the `validateGPGRequirements` method, so that the
same security constraints are enforced consistently across the supported distributions.

Specific security fixes included:
- Rocky Linux: Fixed an issue where the `CHECKSUM` file was downloaded but not GPG verified.
- CentOS: Fixed an issue where `SHA256SUM` and `CHECKSUM` files were downloaded but not GPG verified.
- Gentoo: Added GPG requirement validation for the portage snapshot download URL.

Fixes: https://github.com/lxc/distrobuilder/issues/963
Signed-off-by: Chaosoffire <81634128+chaosoffire@users.noreply.github.com>
2026-01-07 16:42:12 +08:00

307 lines
6.8 KiB
Go

package sources
import (
"context"
"fmt"
"hash"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"time"
"github.com/lxc/incus/v6/shared/ioprogress"
incus "github.com/lxc/incus/v6/shared/util"
"github.com/sirupsen/logrus"
"github.com/lxc/distrobuilder/shared"
)
// common carries the state shared by every source downloader: the logger, the
// image definition being built, the working directories, the context handed to
// downloads and external commands, and the HTTP client used for downloads.
type common struct {
	logger     *logrus.Logger
	definition shared.Definition

	rootfsDir, cacheDir, sourcesDir string

	ctx    context.Context
	client *http.Client
}
// sharedHTTPTransport is the transport every httpCustomTransport delegates to.
// It is a clone of http.DefaultTransport, taken once at package initialisation,
// with a longer TLS handshake timeout. Using a single instance lets keep-alive
// connections be pooled and reused. Cloning per request would give every request
// its own connection pool and leave idle connections open until they time out.
var sharedHTTPTransport = func() *http.Transport {
	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.TLSHandshakeTimeout = 60 * time.Second

	return transport
}()

// httpCustomTransport is an http.RoundTripper that sends "Accept: */*" on
// requests that do not specify an Accept header.
type httpCustomTransport struct{}

// RoundTrip implements http.RoundTripper.
//
// When req has no Accept header, a clone of req carrying "Accept: */*" is sent.
// req itself is never modified, as the http.RoundTripper contract requires.
func (ct *httpCustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Header.Get("Accept") == "" {
		req = req.Clone(req.Context())
		req.Header.Set("Accept", "*/*")
	}

	return sharedHTTPTransport.RoundTrip(req)
}
// init fills in the shared state of a source: its context, logger, image
// definition and working directories. It also creates an HTTP client whose
// transport is an httpCustomTransport.
func (s *common) init(ctx context.Context, logger *logrus.Logger, definition shared.Definition, rootfsDir string, cacheDir string, sourcesDir string) {
	s.ctx = ctx
	s.logger, s.definition = logger, definition
	s.rootfsDir, s.cacheDir, s.sourcesDir = rootfsDir, cacheDir, sourcesDir
	s.client = &http.Client{Transport: &httpCustomTransport{}}
}
// getTargetDir returns the per-image directory used for downloads and temporary
// GPG homes. It is a sub-directory of sourcesDir named
// "<distribution>-<release>-<architecture>", with spaces removed and lower-cased.
//
// Only that final name is normalised. sourcesDir is used verbatim, so a sources
// directory whose path contains spaces or upper-case letters is honoured rather
// than silently rewritten to a different location.
func (s *common) getTargetDir() string {
	name := fmt.Sprintf("%s-%s-%s", s.definition.Image.Distribution, s.definition.Image.Release, s.definition.Image.ArchitectureMapped)
	name = strings.ToLower(strings.ReplaceAll(name, " ", ""))

	return filepath.Join(s.sourcesDir, name)
}
// DownloadHash downloads a file. If a checksum file is provided, it will try and
// match the hash.
//
// The file is stored as <target dir>/<base name of file>, and the target
// directory (see getTargetDir) is returned.
//
// When checksum is non-empty:
//   - The checksum file is fetched with downloadChecksum, wrapped in
//     shared.Retry with a count of 3.
//   - The downloaded file must match one of the resulting hashes, computed with
//     hashFunc. An error is returned if no hash is found at all.
//
// A non-empty file already present on disk is reused. It is re-hashed when both
// a checksum and hashFunc are given, and trusted as-is otherwise. Downloads are
// wrapped in shared.Retry with a count of 3. If every attempt fails, the partial
// file is removed.
//
// def is not used by this method.
func (s *common) DownloadHash(def shared.DefinitionImage, file, checksum string, hashFunc hash.Hash) (string, error) {
	var hashes []string

	destDir := s.getTargetDir()

	err := os.MkdirAll(destDir, 0o755)
	if err != nil {
		return "", err
	}

	if checksum != "" {
		hashLen := 0
		if hashFunc != nil {
			hashFunc.Reset()
			// Length of the hex-encoded digest.
			hashLen = hashFunc.Size() * 2
		}

		err = shared.Retry(func() error {
			var dlErr error
			hashes, dlErr = downloadChecksum(s.ctx, s.client, destDir, checksum, file, hashFunc, hashLen)

			return dlErr
		}, 3)
		if err != nil {
			return "", fmt.Errorf("Error while downloading checksum: %w", err)
		}

		// With no candidate hash the download loop below would never run, and an
		// empty, unverified file would be reported as a success.
		if len(hashes) == 0 {
			return "", fmt.Errorf("No checksum found for %s in %s", file, checksum)
		}
	}

	imagePath := filepath.Join(destDir, filepath.Base(file))

	stat, err := os.Stat(imagePath)
	if err == nil && stat.Size() > 0 {
		// A nil hashFunc cannot hash anything (copying into it would panic). The
		// download path below receives the same nil hashFunc, so treat both alike.
		if checksum != "" && hashFunc != nil {
			image, err := os.Open(imagePath)
			if err != nil {
				return "", err
			}
			defer image.Close()

			hashFunc.Reset()

			_, err = io.Copy(hashFunc, image)
			if err != nil {
				return "", err
			}

			result := fmt.Sprintf("%x", hashFunc.Sum(nil))

			matched := false
			for _, h := range hashes {
				if h == result {
					matched = true
					break
				}
			}

			if !matched {
				return "", fmt.Errorf("Hash mismatch for %s: %s != %v", imagePath, result, hashes)
			}
		}

		return destDir, nil
	}

	image, err := os.Create(imagePath)
	if err != nil {
		return "", err
	}
	defer image.Close()

	progress := func(progress ioprogress.ProgressData) {
		fmt.Printf("%s\r", progress.Text)
	}

	// fetch downloads file into image from scratch. want is the expected hex
	// digest ("" for none) and fn the hash used to check it (nil for none).
	// Between attempts the file is truncated rather than removed. image stays open
	// across retries, so after a removal a successful retry would write to an
	// unlinked file and leave nothing at imagePath.
	fetch := func(want string, fn hash.Hash) error {
		_, err := image.Seek(0, io.SeekStart)
		if err != nil {
			return err
		}

		err = image.Truncate(0)
		if err != nil {
			return err
		}

		if fn != nil {
			fn.Reset()
		}

		_, err = incus.DownloadFileHash(s.ctx, s.client, "distrobuilder", progress, nil, imagePath, file, want, fn, image)

		return err
	}

	if checksum == "" {
		err = shared.Retry(func() error {
			return fetch("", nil)
		}, 3)
	} else {
		// Check all file hashes in case multiple have been provided.
		err = shared.Retry(func() error {
			var lastErr error
			for _, h := range hashes {
				lastErr = fetch(h, hashFunc)
				if lastErr == nil {
					return nil
				}
			}

			return lastErr
		}, 3)
	}

	if err != nil {
		// Don't leave a partial or unverified file behind for a later run to reuse.
		_ = image.Close()
		_ = os.Remove(imagePath)

		return "", err
	}

	fmt.Println("")

	return destDir, nil
}
// GetSignedContent verifies the provided file, and returns its decrypted (plain) content.
//
// A temporary keyring is built with CreateGPGKeyring, and its directory is
// removed before returning. gpg runs under s.ctx, so cancelling the context
// stops it. On failure, the error includes what gpg wrote to stderr.
//
// NOTE(review): this relies on `gpg --decrypt` failing for content that is not
// validly signed. Confirm that unsigned OpenPGP input cannot pass. If it can,
// check the gpg status output (e.g. GOODSIG) explicitly.
func (s *common) GetSignedContent(signedFile string) ([]byte, error) {
	keyring, err := s.CreateGPGKeyring()
	if err != nil {
		return nil, err
	}

	gpgDir := path.Dir(keyring)
	defer os.RemoveAll(gpgDir)

	// stdout carries the payload. Capture stderr separately so gpg's diagnostics
	// end up in the error instead of being dropped.
	var stderr strings.Builder

	cmd := exec.CommandContext(s.ctx, "gpg", "--homedir", gpgDir, "--keyring", keyring,
		"--decrypt", signedFile)
	cmd.Stderr = &stderr

	out, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("Failed to get file content: %s: %w", strings.TrimSpace(stderr.String()), err)
	}

	return out, nil
}
// VerifyFile verifies a file using gpg.
//
// With a non-empty signatureFile, gpg --verify checks it as a detached signature
// over signedFile. Otherwise signedFile is passed to gpg --verify on its own. A
// temporary keyring from CreateGPGKeyring is used and its directory is removed
// afterwards. It returns true only when gpg succeeds. On failure the error
// includes the captured command output.
func (s *common) VerifyFile(signedFile, signatureFile string) (bool, error) {
	keyring, err := s.CreateGPGKeyring()
	if err != nil {
		return false, err
	}

	homeDir := path.Dir(keyring)
	defer os.RemoveAll(homeDir)

	gpgArgs := []string{"--homedir", homeDir, "--keyring", keyring, "--verify"}
	if signatureFile != "" {
		// Detached signature: the signature goes before the signed data.
		gpgArgs = append(gpgArgs, signatureFile)
	}
	gpgArgs = append(gpgArgs, signedFile)

	var output strings.Builder

	err = shared.RunCommand(s.ctx, nil, &output, "gpg", gpgArgs...)
	if err != nil {
		return false, fmt.Errorf("Failed to verify: %s: %w", output.String(), err)
	}

	return true, nil
}
// CreateGPGKeyring creates a new GPG keyring.
//
// It works in three steps:
//   - Create a fresh temporary GPG home ("gpg.*") inside the target directory.
//   - Fetch s.definition.Source.Keys from s.definition.Source.Keyserver into it,
//     making up to three attempts two seconds apart and stopping early if s.ctx
//     is cancelled.
//   - Export the keys to "distrobuilder.gpg" so the file works with both gpg1
//     and gpg2.
//
// It returns the path of that file. The callers in this file remove its parent
// directory when done. On any failure the temporary directory is removed and an
// error is returned.
func (s *common) CreateGPGKeyring() (string, error) {
	targetDir := s.getTargetDir()

	err := os.MkdirAll(targetDir, 0o700)
	if err != nil {
		return "", err
	}

	// MkdirTemp creates the directory itself, with mode 0700.
	gpgDir, err := os.MkdirTemp(targetDir, "gpg.")
	if err != nil {
		return "", fmt.Errorf("Failed to create gpg directory: %w", err)
	}

	// Remove the half-built GPG home on every failure path below.
	keep := false
	defer func() {
		if !keep {
			_ = os.RemoveAll(gpgDir)
		}
	}()

	var ok bool
	for attempt := 0; attempt < 3; attempt++ {
		if attempt > 0 {
			// Wait before retrying, but give up promptly if the context is cancelled.
			select {
			case <-s.ctx.Done():
				return "", s.ctx.Err()
			case <-time.After(2 * time.Second):
			}
		}

		ok, err = recvGPGKeys(s.ctx, gpgDir, s.definition.Source.Keyserver, s.definition.Source.Keys)
		if ok {
			break
		}
	}

	if !ok {
		// Never report success-with-no-keyring: an empty path with a nil error
		// would let callers carry on without verification.
		if err == nil {
			err = fmt.Errorf("Failed to receive GPG keys")
		}

		return "", err
	}

	keyring := filepath.Join(gpgDir, "distrobuilder.gpg")

	var out strings.Builder

	// Export keys to support gpg1 and gpg2
	err = shared.RunCommand(s.ctx, nil, &out, "gpg", "--homedir", gpgDir, "--export", "--output", keyring)
	if err != nil {
		return "", fmt.Errorf("Failed to export keyring: %s: %w", out.String(), err)
	}

	keep = true

	return keyring, nil
}
// validateGPGRequirements decides whether GPG verification may be skipped for a
// download from u. It returns true (skip) only when no keys are configured and u
// uses HTTPS. In that case a warning is logged unless the definition sets
// SkipVerification. Configured keys always force verification. Without keys, any
// scheme other than HTTPS is rejected with an error.
func (s *common) validateGPGRequirements(u *url.URL) (bool, error) {
	switch {
	case len(s.definition.Source.Keys) > 0:
		// Keys were provided: verify regardless of protocol.
		return false, nil
	case u.Scheme != "https":
		// Without keys, only HTTPS is acceptable; anything else must be GPG checked.
		return false, fmt.Errorf("GPG keys are required if downloading from %s", u.Scheme)
	}

	// HTTPS without keys is allowed; warn unless verification was explicitly skipped.
	if !s.definition.Source.SkipVerification {
		s.logger.Warnf("Downloading from %s without GPG keys as no keys were specified", u.Scheme)
	}

	return true, nil
}