1
0
mirror of https://github.com/openshift/installer.git synced 2026-02-05 15:47:14 +01:00
Files
installer/pkg/asset/installconfig/aws/session.go
Thuan Vo d67b14e479 CORS-4055: migrate credential provider check to AWS SDK v2
This commit is an incremental step to migrate AWS API calls
to AWS SDK v2. This focuses on handlers that retrieve the source
or provider of credentials, for example, via shared credential file
and via environment variables.

Note: this logic determines whether the credential provider is static,
in which case the credentials are safe to transfer to the cluster as-is
in Mint and Passthrough credentialsMode.
2026-01-27 12:03:27 -08:00

294 lines
9.0 KiB
Go

package aws
import (
"errors"
"fmt"
"os"
"path/filepath"
"sync"
survey "github.com/AlecAivazis/survey/v2"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/sirupsen/logrus"
ini "gopkg.in/ini.v1"
typesaws "github.com/openshift/installer/pkg/types/aws"
"github.com/openshift/installer/pkg/version"
)
// onceLoggers maps a credential source name to the sync.Once that guards the
// log message announcing that source.
var onceLoggers = map[string]*sync.Once{
	credentials.SharedCredsProviderName: &sync.Once{},
	credentials.EnvProviderName:         &sync.Once{},
	"credentialsFromSession":            &sync.Once{},
}
// SessionOptions is a functional option: a callback that adjusts the
// session.Options it is handed.
type SessionOptions func(opts *session.Options)
// WithRegion returns a SessionOptions that merges the given AWS region into
// the session configuration.
func WithRegion(region string) SessionOptions {
	return func(opts *session.Options) {
		opts.Config.MergeIn(aws.NewConfig().WithRegion(region))
	}
}
// WithServiceEndpoints returns a SessionOptions that merges into the session
// configuration an endpoint resolver built from region and the provided
// service endpoints.
func WithServiceEndpoints(region string, services []typesaws.ServiceEndpoint) SessionOptions {
	return func(opts *session.Options) {
		endpointResolver := newAWSResolver(region, services)
		opts.Config.MergeIn(aws.NewConfig().WithEndpointResolver(endpointResolver))
	}
}
// GetSession returns an AWS session built with no extra options. It is a
// shorthand for calling GetSessionWithOptions without arguments.
func GetSession() (*session.Session, error) {
	return GetSessionWithOptions()
}
// GetSessionWithOptions returns an AWS session after checking that
// credentials can be loaded. If the check fails with a
// "NoCredentialProviders" error, the user is prompted for credentials, which
// are stored on disk in a credentials file. Any other failure is returned to
// the caller.
//
// The session options start with MaxRetries set to 0 and are then adjusted
// by each entry of optFuncs in order. The returned session is a copy with
// MaxRetries set to 25 and a build handler that appends the installer name
// and version to the user agent.
func GetSessionWithOptions(optFuncs ...SessionOptions) (*session.Session, error) {
	options := session.Options{
		Config:            aws.Config{MaxRetries: aws.Int(0)},
		SharedConfigState: session.SharedConfigEnable,
	}
	for _, optFunc := range optFuncs {
		optFunc(&options)
	}

	if _, err := getCredentials(options); err != nil {
		if !errCodeEquals(err, "NoCredentialProviders") {
			return nil, err
		}
		// No provider produced credentials: ask the user and persist the answer.
		if promptErr := getUserCredentials(); promptErr != nil {
			return nil, promptErr
		}
	}

	// Report session creation failures through the error return instead of
	// panicking: callers already have to handle an error from this function.
	ssn, err := session.NewSessionWithOptions(options)
	if err != nil {
		return nil, fmt.Errorf("error creating AWS session: %w", err)
	}

	ssn = ssn.Copy(&aws.Config{MaxRetries: aws.Int(25)})
	ssn.Handlers.Build.PushBackNamed(request.NamedHandler{
		Name: "openshiftInstaller.OpenshiftInstallerUserAgentHandler",
		Fn:   request.MakeAddToUserAgentHandler("OpenShift/4.x Installer", version.Raw),
	})
	return ssn, nil
}
// getCredentials tries the environment and shared-credentials-file providers
// first. When that chain reports "NoCredentialProviders" it falls back to
// getCredentialsFromSession; any other chain error is wrapped and returned.
// On success the source of the credentials is logged once per source.
func getCredentials(options session.Options) (*credentials.Credentials, error) {
	sharedProvider := &credentials.SharedCredentialsProvider{}
	chain := credentials.NewChainCredentials([]credentials.Provider{
		&credentials.EnvProvider{},
		sharedProvider,
	})

	value, err := chain.Get()
	if err != nil {
		if errCodeEquals(err, "NoCredentialProviders") {
			// getCredentialsFromSession returns credentials derived from a
			// session. A session uses the AWS SDK Go chain of providers so may
			// use a provider (e.g., STS) which provides temporary credentials.
			return getCredentialsFromSession(options)
		}
		return nil, fmt.Errorf("error loading credentials for AWS Provider: %w", err)
	}

	// Report which provider supplied the credentials.
	if name := value.ProviderName; name == credentials.SharedCredsProviderName {
		onceLoggers[credentials.SharedCredsProviderName].Do(func() {
			logrus.Infof("Credentials loaded from the %q profile in file %q", sharedProvider.Profile, sharedProvider.Filename)
		})
	} else if name == credentials.EnvProviderName {
		onceLoggers[credentials.EnvProviderName].Do(func() {
			logrus.Info("Credentials loaded from default AWS environment variables")
		})
	}

	return chain, nil
}
// getCredentialsFromSession creates a session from options and returns the
// credentials attached to its configuration, after confirming they can be
// retrieved. The name of the provider that supplied them is logged once.
func getCredentialsFromSession(options session.Options) (*credentials.Credentials, error) {
	sess, err := session.NewSessionWithOptions(options)
	if err != nil {
		if !errCodeEquals(err, "NoCredentialProviders") {
			return nil, fmt.Errorf("error creating AWS session: %w", err)
		}
		return nil, fmt.Errorf("failed to get credentials from session: %w", err)
	}

	sessionCreds := sess.Config.Credentials
	value, err := sessionCreds.Get()
	if err != nil {
		// Returned unwrapped so callers can inspect the original error.
		return nil, err
	}

	onceLoggers["credentialsFromSession"].Do(func() {
		logrus.Infof("Credentials loaded from the AWS config using %q provider", value.ProviderName)
	})
	return sessionCreds, nil
}
// errCodeEquals reports whether err (or an error it wraps) is an
// awserr.Error whose Code() is exactly code. It returns false otherwise.
func errCodeEquals(err error, code string) bool {
	var target awserr.Error
	return errors.As(err, &target) && target.Code() == code
}
// getUserCredentials interactively asks for an AWS access key ID and secret
// access key and stores them in the shared credentials file.
//
// The file path is the SDK default unless AWS_SHARED_CREDENTIALS_FILE is set.
// The section written is named by AWS_PROFILE, or "default" when that is
// empty. Existing file content is loaded first, and only the two key entries
// of that section are set.
//
// The result is written to "<path>.tmp" (created exclusively with mode 0600)
// and then renamed over the target. On any failure after the temporary file
// is created, it is removed so that a later attempt is not blocked by a
// stale file.
func getUserCredentials() error {
	var keyID string
	err := survey.Ask([]*survey.Question{
		{
			Prompt: &survey.Input{
				Message: "AWS Access Key ID",
				Help:    "The AWS access key ID to use for installation (this is not your username).\nhttps://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_access-keys.html",
			},
		},
	}, &keyID)
	if err != nil {
		return err
	}

	var secretKey string
	err = survey.Ask([]*survey.Question{
		{
			Prompt: &survey.Password{
				Message: "AWS Secret Access Key",
				Help:    "The AWS secret access key corresponding to your access key ID (this is not your password).",
			},
		},
	}, &secretKey)
	if err != nil {
		return err
	}

	path := defaults.SharedCredentialsFilename()
	if env := os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); env != "" {
		path = env
	}
	logrus.Infof("Writing AWS credentials to %q (https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html)", path)
	if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
		return err
	}

	creds, err := ini.Load(path)
	if err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return fmt.Errorf("failed to load credentials file %s: %w", path, err)
		}
		// No credentials file yet: start from an empty one.
		creds = ini.Empty()
		creds.Section("").Comment = "https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html"
	}

	profile := os.Getenv("AWS_PROFILE")
	if profile == "" {
		profile = "default"
	}
	creds.Section(profile).Key("aws_access_key_id").SetValue(keyID)
	creds.Section(profile).Key("aws_secret_access_key").SetValue(secretKey)

	tempPath := path + ".tmp"
	file, err := os.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return err
	}
	// removeTemp is best-effort cleanup; a failure is logged, not returned,
	// so the original error reaches the caller.
	removeTemp := func() {
		if rmErr := os.Remove(tempPath); rmErr != nil {
			logrus.Error(fmt.Errorf("failed to remove partially-written credentials file: %w", rmErr))
		}
	}

	_, err = creds.WriteTo(file)
	// Close before renaming or removing: the handle must not be open while
	// the file is moved, and a close error means the write may be incomplete.
	if closeErr := file.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		removeTemp()
		return err
	}

	if err := os.Rename(tempPath, path); err != nil {
		removeTemp()
		return err
	}
	return nil
}
// awsResolver resolves AWS service endpoints, preferring user-supplied
// service overrides, then a table of known per-region defaults.
type awsResolver struct {
	// region is used as the signing region for overridden services when the
	// SDK default resolver does not supply one.
	region string
	// services holds the user-supplied endpoint overrides, keyed by
	// resolverKey(service name).
	services map[string]typesaws.ServiceEndpoint

	// defaultEndpoints lists known default endpoints for specific regions
	// that would otherwise require the user to set service overrides. It maps
	// region => service => resolved endpoint and is consulted only when the
	// user has not specified an override for the service.
	defaultEndpoints map[string]map[string]endpoints.ResolvedEndpoint
}
// newAWSResolver builds an awsResolver for region, indexing the given
// service endpoints by resolverKey of their name. If two entries share a key,
// the later one wins.
func newAWSResolver(region string, services []typesaws.ServiceEndpoint) *awsResolver {
	overrides := make(map[string]typesaws.ServiceEndpoint)
	for _, svc := range services {
		overrides[resolverKey(svc.Name)] = svc
	}

	return &awsResolver{
		region:           region,
		services:         overrides,
		defaultEndpoints: defaultEndpoints(),
	}
}
// EndpointFor resolves the endpoint for service in region. Lookup order:
//  1. a user-supplied override for the service;
//  2. the built-in table of known per-region defaults;
//  3. the SDK default resolver, which is the only step that receives optFns.
func (ar *awsResolver) EndpointFor(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
	override, overridden := ar.services[resolverKey(service)]
	if !overridden {
		// Indexing a missing region yields a nil inner map, so the second
		// lookup simply reports not-found.
		if known, ok := ar.defaultEndpoints[region][service]; ok {
			return known, nil
		}
		return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
	}

	logrus.Debugf("resolved AWS service %s (%s) to %q", service, region, override.URL)

	resolved := endpoints.ResolvedEndpoint{
		URL:           override.URL,
		SigningRegion: ar.region,
	}
	// Prefer the signing region the SDK would pick for this service. The
	// lookup error is discarded; only a non-empty SigningRegion is used.
	if sdkDefault, _ := endpoints.DefaultResolver().EndpointFor(service, region); sdkDefault.SigningRegion != "" {
		resolved.SigningRegion = sdkDefault.SigningRegion
	}
	return resolved, nil
}
// resolverKey returns the map key under which a service's endpoint override
// is stored and looked up. It currently returns the service name unchanged.
func resolverKey(service string) string { return service }
// defaultEndpoints returns the known default endpoints for specific regions
// that would otherwise require the user to set service overrides.
// The result maps region => service => resolved endpoint, and is only used
// when the user hasn't specified an override for the service in that region.
func defaultEndpoints() map[string]map[string]endpoints.ResolvedEndpoint {
	// Both China regions share one route53 endpoint and signing region.
	chinaRoute53 := endpoints.ResolvedEndpoint{
		URL:           "https://route53.amazonaws.com.cn",
		SigningRegion: endpoints.CnNorthwest1RegionID,
	}

	return map[string]map[string]endpoints.ResolvedEndpoint{
		endpoints.CnNorth1RegionID:     {"route53": chinaRoute53},
		endpoints.CnNorthwest1RegionID: {"route53": chinaRoute53},
	}
}