Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ vendor/
.ignore*
local/
.deps/
.cache/
.cache/
*.env
14 changes: 9 additions & 5 deletions aws-cli-auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,33 @@ package main

import (
"context"
"log"
"os"
"os/signal"
"syscall"
"time"

"github.com/DevLabFoundry/aws-cli-auth/cmd"
"github.com/rs/zerolog"
)

func main() {
ctx, stop := signal.NotifyContext(context.Background(), []os.Signal{os.Interrupt, syscall.SIGTERM, os.Kill}...)
defer stop()
logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}).
Level(zerolog.ErrorLevel).
With().Timestamp().
Logger()

go func() {
<-ctx.Done()
stop()
// log.Printf("\x1b[31minterrupted: %s\x1b[0m", ctx.Err())
os.Exit(0)
logger.Fatal().Msgf("\x1b[31minterrupted: %s\x1b[0m", ctx.Err())
}()

c := cmd.New()
c := cmd.New(logger)
c.WithSubCommands(cmd.SubCommands()...)

if err := c.Execute(ctx); err != nil {
log.Fatalf("\x1b[31m%s\x1b[0m", err)
logger.Fatal().Msgf("\x1b[31m%s\x1b[0m", err)
}
}
10 changes: 7 additions & 3 deletions cmd/awscliauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (

"github.com/DevLabFoundry/aws-cli-auth/internal/credentialexchange"
"github.com/Ensono/eirctl/selfupdate"
"github.com/rs/zerolog"
"github.com/savioxavier/termlink"
"github.com/spf13/cobra"
)

Expand All @@ -23,6 +25,7 @@ type Root struct {
// ChannelErr io.Writer
// viperConf *viper.Viper
rootFlags *RootCmdFlags
logger zerolog.Logger
Datadir string
}

Expand All @@ -35,9 +38,10 @@ type RootCmdFlags struct {
CustomIniLocation string
}

func New() *Root {
func New(logger zerolog.Logger) *Root {
rf := &RootCmdFlags{}
r := &Root{
logger: logger,
rootFlags: rf,
Cmd: &cobra.Command{
Use: "aws-cli-auth",
Expand All @@ -57,8 +61,8 @@ Stores them under the $HOME/.aws/credentials file under a specified path or retu
r.Cmd.PersistentFlags().StringVarP(&rf.CfgSectionName, "cfg-section", "", "", "Config section name to use in the look up of the config ini file (~/.aws-cli-auth.ini) and in the AWS credentials file")
// When specifying store in profile the config section name must be provided
r.Cmd.MarkFlagsRequiredTogether("store-profile", "cfg-section")
r.Cmd.PersistentFlags().IntVarP(&rf.Duration, "max-duration", "d", 900, `Override default max session duration, in seconds, of the role session [900-43200].
NB: This cannot be higher than the 3600 as the API does not allow for AssumeRole for sessions longer than an hour`)
r.Cmd.PersistentFlags().IntVarP(&rf.Duration, "max-duration", "d", 900, fmt.Sprintf("Override default max session duration, in seconds, of the role session [900-43200].\nNB: This cannot be higher than the 3600 as the API does not allow for AssumeRole for sessions longer than an hour\nMore info on this and especially around role-chaining\nSee %s",
termlink.Link("AWS SDK Duration", "https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_RequestParameters")))
r.Cmd.PersistentFlags().BoolVarP(&rf.Verbose, "verbose", "v", false, "Verbose output")
r.Cmd.PersistentFlags().StringVarP(&rf.CustomIniLocation, "config-file", "c", "", "Specify the custom location of config file")

Expand Down
3 changes: 2 additions & 1 deletion cmd/awscliauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ import (
"github.com/DevLabFoundry/aws-cli-auth/cmd"
"github.com/DevLabFoundry/aws-cli-auth/internal/credentialexchange"
"github.com/DevLabFoundry/aws-cli-auth/internal/web"
"github.com/rs/zerolog"
)

func cmdHelperExecutor(t *testing.T, args []string) (stdOut *bytes.Buffer, errOut *bytes.Buffer, err error) {
t.Helper()
errOut = new(bytes.Buffer)
stdOut = new(bytes.Buffer)
c := cmd.New()
c := cmd.New(zerolog.New(io.Discard))
c.WithSubCommands(cmd.SubCommands()...)
c.Cmd.SetArgs(args)
c.Cmd.SetErr(errOut)
Expand Down
56 changes: 45 additions & 11 deletions cmd/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ import (
"github.com/DevLabFoundry/aws-cli-auth/internal/web"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sts"
validator "github.com/rezakhademix/govalidator/v2"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"gopkg.in/ini.v1"
)

var (
ErrUnableToCreateSession = errors.New("sts - cannot start a new session")
ErrValidationFailed = errors.New("missing values")
)

const (
Expand Down Expand Up @@ -64,17 +67,24 @@ func newSamlCmd(r *Root) {
if err != nil {
return err
}

if r.rootFlags.Verbose {
r.logger = r.logger.Level(zerolog.DebugLevel)
}
r.logger.Debug().Str("CustomIniLocation", r.rootFlags.CustomIniLocation).Msg("if empty using default ~/.aws-cli-auth.ini")
iniFile, err := samlInitConfig(r.rootFlags.CustomIniLocation)
if err != nil {
return err
}

r.logger.Debug().Msgf("iniFile: %+v", iniFile)

conf, err := credentialexchange.LoadCliConfig(iniFile, r.rootFlags.CfgSectionName)
if err != nil {
return err
}

r.logger.Debug().Str("section", r.rootFlags.CfgSectionName).Msgf("loaded section: %+v", conf)

if err := ConfigFromFlags(conf, r.rootFlags, flags, user.Username); err != nil {
return err
}
Expand All @@ -95,30 +105,35 @@ func newSamlCmd(r *Root) {
saveRole = allRoles[len(allRoles)-1]
}

r.logger.Debug().Str("saveRole", saveRole).
Str("SsoEndpoint", conf.SsoUserEndpoint).
Str("SsoCredFedEndpoint", conf.SsoCredFedEndpoint).
Msg("")

secretStore, err := credentialexchange.NewSecretStore(saveRole,
fmt.Sprintf("%s-%s", credentialexchange.SELF_NAME, credentialexchange.RoleKeyConverter(saveRole)),
os.TempDir(), user.Username)
if err != nil {
return err
}

// we want to remove any AWS_* env vars that could interfere with the default config
// for _, envVar := range []string{"AWS_PROFILE", "AWS_ACCESS_KEY_ID",
// "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"} {
// os.Unsetenv(envVar)
// }

awsConf, err := config.LoadDefaultConfig(ctx)
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return fmt.Errorf("failed to create session %s, %w", err, ErrUnableToCreateSession)
}

svc := sts.NewFromConfig(awsConf)
if cfg.Region == "" {
// cfg.en
return fmt.Errorf("unable to deduce AWS region, AWS_REGION, AWS_DEFAULT_REGION, ~/.aws/config default or profile level region must be set")
}

svc := sts.NewFromConfig(cfg)
cre := credentialexchange.New(r.logger, svc)
webConfig := web.NewWebConf(r.Datadir).
WithTimeout(flags.SamlTimeout).
WithCustomExecutable(conf.BaseConfig.BrowserExecutablePath)

return cmdutils.GetCredsWebUI(ctx, svc, secretStore, *conf, webConfig)
return cmdutils.GetCredsWebUI(ctx, cre, secretStore, *conf, webConfig)

},
PreRunE: func(cmd *cobra.Command, args []string) error {
Expand Down Expand Up @@ -165,7 +180,7 @@ If this flag is specified the --sso-role must also be specified.`)
// sc.cmd.MarkFlagsRequiredTogether("principal", "role")
// SSO flow for SAML
sc.cmd.MarkFlagsRequiredTogether("is-sso", "sso-role", "sso-region")
sc.cmd.PersistentFlags().Int32VarP(&flags.SamlTimeout, "saml-timeout", "", 120, "Timeout in seconds, before the operation of waiting for a response is cancelled via the chrome driver")
sc.cmd.PersistentFlags().Int32VarP(&flags.SamlTimeout, "saml-timeout", "", 120, "Timeout in seconds, before the operation of waiting for a response is cancelled via CDP (ChromeDeubgProto)")
// Add subcommand to root command
r.Cmd.AddCommand(sc.cmd)
}
Expand Down Expand Up @@ -219,5 +234,24 @@ func ConfigFromFlags(fileConfig *credentialexchange.CredentialConfig, rf *RootCm

fileConfig.BaseConfig = baseConf
fileConfig.Duration = d

return configValid(fileConfig)
}

func configValid(config *credentialexchange.CredentialConfig) error {
v := validator.New()
ssoVal := !config.IsSso
if config.IsSso {
ssoVal = len(config.SsoRole) > 0 && len(config.SsoRegion) > 0
}
v.RequiredString(config.ProviderUrl, "provider-url", "provider url must be specified").
// RequiredString(config.BaseConfig.Role, "role", "role must be provided").
RequiredString(config.PrincipalArn, "principal-arn", "principal ARN must be provided").
CustomRule(ssoVal, "is-sso", "sso-role must be specified when is-sso is set").
CustomRule((len(config.BaseConfig.Role) > 1 && len(config.SsoRole) < 1) || (len(config.BaseConfig.Role) < 1 && len(config.SsoRole) > 1), "role", "sso-role cannot be specified when role is also set")

if v.IsFailed() {
return fmt.Errorf("%w %#q", ErrValidationFailed, v.Errors())
}
return nil
}
77 changes: 75 additions & 2 deletions cmd/saml_test.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
package cmd_test

import (
"errors"
"testing"

"github.com/DevLabFoundry/aws-cli-auth/cmd"
"github.com/DevLabFoundry/aws-cli-auth/internal/credentialexchange"
"github.com/go-test/deep"
)

func Test_ConfigMerge(t *testing.T) {
func Test_ConfigMerge_succeeds(t *testing.T) {
conf := &credentialexchange.CredentialConfig{
BaseConfig: credentialexchange.BaseConfig{
BrowserExecutablePath: "/foo/path",
Role: "role1",
RoleChain: []string{"role-123"},
},
ProviderUrl: "https://my-idp.com/?app_id=testdd",
PrincipalArn: "aw:arn:....123",
ProviderUrl: "https://my-idp.com/?app_id=testdd",
}
if err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "role-overridden-from-flags"}, "me"); err != nil {
t.Fatal(err)
Expand All @@ -28,8 +30,79 @@ func Test_ConfigMerge(t *testing.T) {
RoleChain: []string{"role-123"},
Username: "me",
},
PrincipalArn: "aw:arn:....123",
}
if diff := deep.Equal(conf, want); len(diff) > 0 {
t.Errorf("diff: %v", diff)
}
}

func Test_ConfigMerge_fails_with_missing(t *testing.T) {
t.Run("provider not provided", func(t *testing.T) {

conf := &credentialexchange.CredentialConfig{
BaseConfig: credentialexchange.BaseConfig{
BrowserExecutablePath: "/foo/path",
Role: "",
RoleChain: []string{"role-123"},
},
ProviderUrl: "",
}
err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "role-overridden-from-flags"}, "me")
if !errors.Is(err, cmd.ErrValidationFailed) {
t.Error(err)
}
})
t.Run("role not provided", func(t *testing.T) {

conf := &credentialexchange.CredentialConfig{
BaseConfig: credentialexchange.BaseConfig{
BrowserExecutablePath: "/foo/path",
Role: "",
RoleChain: []string{"role-123"},
},
ProviderUrl: "https://my-idp.com/?app_id=testdd",
}
err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{}, "me")
if !errors.Is(err, cmd.ErrValidationFailed) {
t.Error(err)
}
})
t.Run("is-sso set but sso-role not set", func(t *testing.T) {

conf := &credentialexchange.CredentialConfig{
BaseConfig: credentialexchange.BaseConfig{
BrowserExecutablePath: "/foo/path",
Role: "",
RoleChain: []string{"role-123"},
},
PrincipalArn: "some-arn",
IsSso: true,
SsoRegion: "",
SsoRole: "foo:bar",
ProviderUrl: "https://my-idp.com/?app_id=testdd",
}
err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{}, "me")
if !errors.Is(err, cmd.ErrValidationFailed) {
t.Error(err)
}
})
t.Run("role and sso-role both provided", func(t *testing.T) {

conf := &credentialexchange.CredentialConfig{
BaseConfig: credentialexchange.BaseConfig{
BrowserExecutablePath: "/foo/path",
Role: "",
RoleChain: []string{"role-123"},
},
PrincipalArn: "some-arn",
SsoRegion: "foo",
SsoRole: "foo:bar",
ProviderUrl: "https://my-idp.com/?app_id=testdd",
}
err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "wrong-role"}, "me")
if !errors.Is(err, cmd.ErrValidationFailed) {
t.Error(err)
}
})
}
7 changes: 4 additions & 3 deletions cmd/specific.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@ Returns the same JSON object as the call to the AWS CLI for any of the sts Assum
svc := sts.NewFromConfig(cfg)

user, err := user.Current()

if err != nil {
return err
}

cre := credentialexchange.New(r.logger, svc)

if flags.method != "" {
switch flags.method {
case "WEB_ID":
awsCreds, err = credentialexchange.LoginAwsWebToken(ctx, user.Name, svc)
awsCreds, err = cre.LoginAwsWebToken(ctx, user.Name)
if err != nil {
return err
}
Expand All @@ -69,7 +70,7 @@ Returns the same JSON object as the call to the AWS CLI for any of the sts Assum
Duration: r.rootFlags.Duration,
}

awsCreds, err = credentialexchange.AssumeRoleInChain(ctx, awsCreds, svc, config.BaseConfig.Username, config.BaseConfig.RoleChain, conf)
awsCreds, err = cre.AssumeRoleInChain(ctx, awsCreds, config.BaseConfig.Username, config.BaseConfig.RoleChain, conf)
if err != nil {
return err
}
Expand Down
Loading
Loading