From b55301a745dcb54bcc5a1161dc33ac3cbf75d50c Mon Sep 17 00:00:00 2001 From: dnitsch Date: Wed, 17 Dec 2025 20:39:35 +0000 Subject: [PATCH 1/6] fix: add validator to the final config struct --- cmd/saml.go | 17 +++++++++++++++ cmd/saml_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++-- go.mod | 3 ++- go.sum | 12 ++-------- 4 files changed, 76 insertions(+), 13 deletions(-) diff --git a/cmd/saml.go b/cmd/saml.go index e58c17e..6d7328b 100755 --- a/cmd/saml.go +++ b/cmd/saml.go @@ -13,12 +13,14 @@ import ( "github.com/DevLabFoundry/aws-cli-auth/internal/web" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" + validator "github.com/rezakhademix/govalidator/v2" "github.com/spf13/cobra" "gopkg.in/ini.v1" ) var ( ErrUnableToCreateSession = errors.New("sts - cannot start a new session") + ErrValidationFailed = errors.New("missing values") ) const ( @@ -219,5 +221,20 @@ func ConfigFromFlags(fileConfig *credentialexchange.CredentialConfig, rf *RootCm fileConfig.BaseConfig = baseConf fileConfig.Duration = d + + return configValid(fileConfig) +} + +func configValid(config *credentialexchange.CredentialConfig) error { + v := validator.New() + + v.RequiredString(config.ProviderUrl, "provider-url", "provider url must be specified"). + RequiredString(config.BaseConfig.Role, "role", "role must be provided"). + RequiredString(config.PrincipalArn, "principal-arn", "principal ARN must be provided"). + CustomRule(!(len(config.BaseConfig.Role) > 1 && len(config.SsoRole) > 1), "sso-role", "sso-role cannot be specified when role is also set") + + if v.IsFailed() { + return fmt.Errorf("%w %#q", ErrValidationFailed, v.Errors()) + } return nil } diff --git a/cmd/saml_test.go b/cmd/saml_test.go index bd7735f..26895d4 100644 --- a/cmd/saml_test.go +++ b/cmd/saml_test.go @@ -1,6 +1,7 @@ package cmd_test import ( + "errors" "testing" "github.com/DevLabFoundry/aws-cli-auth/cmd" @@ -8,14 +9,15 @@ import ( "github.com/go-test/deep" ) -func Test_ConfigMerge(t *testing.T) { +func Test_ConfigMerge_succeeds(t *testing.T) { conf := &credentialexchange.CredentialConfig{ BaseConfig: credentialexchange.BaseConfig{ BrowserExecutablePath: "/foo/path", Role: "role1", RoleChain: []string{"role-123"}, }, - ProviderUrl: "https://my-idp.com/?app_id=testdd", + PrincipalArn: "aw:arn:....123", + ProviderUrl: "https://my-idp.com/?app_id=testdd", } if err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "role-overridden-from-flags"}, "me"); err != nil { t.Fatal(err) @@ -28,8 +30,59 @@ func Test_ConfigMerge(t *testing.T) { RoleChain: []string{"role-123"}, Username: "me", }, + PrincipalArn: "aw:arn:....123", } if diff := deep.Equal(conf, want); len(diff) > 0 { t.Errorf("diff: %v", diff) } } + +func Test_ConfigMerge_fails_with_missing(t *testing.T) { + t.Run("provider not provided", func(t *testing.T) { + + conf := &credentialexchange.CredentialConfig{ + BaseConfig: credentialexchange.BaseConfig{ + BrowserExecutablePath: "/foo/path", + Role: "", + RoleChain: []string{"role-123"}, + }, + ProviderUrl: "", + } + err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "role-overridden-from-flags"}, "me") + if !errors.Is(err, cmd.ErrValidationFailed) { + t.Error(err) + } + }) + t.Run("role not provided", func(t *testing.T) { + + conf := &credentialexchange.CredentialConfig{ + BaseConfig: credentialexchange.BaseConfig{ + BrowserExecutablePath: "/foo/path", + Role: "", + RoleChain: []string{"role-123"}, + }, + ProviderUrl: "https://my-idp.com/?app_id=testdd", + } + err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{}, "me") + if !errors.Is(err, cmd.ErrValidationFailed) { + t.Error(err) + } + }) + t.Run("role and sso-role provided", func(t *testing.T) { + + conf := &credentialexchange.CredentialConfig{ + BaseConfig: credentialexchange.BaseConfig{ + BrowserExecutablePath: "/foo/path", + Role: "", + RoleChain: []string{"role-123"}, + }, + SsoRegion: "foo", + SsoRole: "foo:bar", + ProviderUrl: "https://my-idp.com/?app_id=testdd", + } + err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{}, "me") + if !errors.Is(err, cmd.ErrValidationFailed) { + t.Error(err) + } + }) +} diff --git a/go.mod b/go.mod index 77132d8..104f769 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/DevLabFoundry/aws-cli-auth -go 1.25.4 +go 1.25.5 require ( github.com/aws/aws-sdk-go-v2 v1.39.6 @@ -8,6 +8,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.40.2 github.com/aws/smithy-go v1.23.2 github.com/go-rod/rod v0.116.2 + github.com/rezakhademix/govalidator/v2 v2.1.2 github.com/spf13/cobra v1.10.1 github.com/werf/lockgate v0.1.1 github.com/zalando/go-keyring v0.2.6 diff --git a/go.sum b/go.sum index 30bbfee..4dd7409 100644 --- a/go.sum +++ b/go.sum @@ -6,12 +6,8 @@ github.com/Ensono/eirctl v0.9.6 h1:G6S0ZJ2VtedGW2/nn8sbMQnNbLVXgjwNJnuLEHUjJRc= github.com/Ensono/eirctl v0.9.6/go.mod h1:pxX1iE+guf8Lyvs98FkNnMKqyTtHaLrJgB3f4foEROk= github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= -github.com/aws/aws-sdk-go-v2/config v1.31.19 h1:qdUtOw4JhZr2YcKO3g0ho/IcFXfXrrb8xlX05Y6EvSw= -github.com/aws/aws-sdk-go-v2/config v1.31.19/go.mod h1:tMJ8bur01t8eEm0atLadkIIFA154OJ4JCKZeQ+o+R7k= github.com/aws/aws-sdk-go-v2/config v1.31.20 h1:/jWF4Wu90EhKCgjTdy1DGxcbcbNrjfBHvksEL79tfQc= github.com/aws/aws-sdk-go-v2/config v1.31.20/go.mod h1:95Hh1Tc5VYKL9NJ7tAkDcqeKt+MCXQB1hQZaRdJIZE0= -github.com/aws/aws-sdk-go-v2/credentials v1.18.23 h1:IQILcxVgMO2BVLaJ2aAv21dKWvE1MduNrbvuK43XL2Q= -github.com/aws/aws-sdk-go-v2/credentials v1.18.23/go.mod h1:JRodHszhVdh5TPUknxDzJzrMiznG+M+FfR3WSWKgCI8= github.com/aws/aws-sdk-go-v2/credentials v1.18.24 h1:iJ2FmPT35EaIB0+kMa6TnQ+PwG5A1prEdAw+PsMzfHg= github.com/aws/aws-sdk-go-v2/credentials v1.18.24/go.mod h1:U91+DrfjAiXPDEGYhh/x29o4p0qHX5HDqG7y5VViv64= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 h1:T1brd5dR3/fzNFAQch/iBKeX07/ffu/cLu+q+RuzEWk= @@ -26,16 +22,10 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/A github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.2 h1:/p6MxkbQoCzaGQT3WO0JwG0FlQyG9RD8VmdmoKc5xqU= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.2/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k= github.com/aws/aws-sdk-go-v2/service/sso v1.30.3 h1:NjShtS1t8r5LUfFVtFeI8xLAHQNTa7UI0VawXlrBMFQ= github.com/aws/aws-sdk-go-v2/service/sso v1.30.3/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.6 h1:0dES42T2dhICCbVB3JSTTn7+Bz93wfJEK1b7jksZIyQ= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.6/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.7 h1:gTsnx0xXNQ6SBbymoDvcoRHL+q4l/dAFsQuKfDWSaGc= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.7/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo= -github.com/aws/aws-sdk-go-v2/service/sts v1.40.1 h1:5sbIM57lHLaEaNWdIx23JH30LNBsSDkjN/QXGcRLAFc= -github.com/aws/aws-sdk-go-v2/service/sts v1.40.1/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk= github.com/aws/aws-sdk-go-v2/service/sts v1.40.2 h1:HK5ON3KmQV2HcAunnx4sKLB9aPf3gKGwVAf7xnx0QT0= github.com/aws/aws-sdk-go-v2/service/sts v1.40.2/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk= github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= @@ -85,6 +75,8 @@ github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2Em github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rezakhademix/govalidator/v2 v2.1.2 h1:qqCIkWC6sWr8zeW9zCkYEJxbZMt/Dn1ASXkGIQe3rDI= +github.com/rezakhademix/govalidator/v2 v2.1.2/go.mod h1:be7JrYM3STiL5jYt1WrQN5ArR8xTov/DvWJ9yXtULj8= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= From 621b0dced5d08dab974dba44564a8f66c39b8fb7 Mon Sep 17 00:00:00 2001 From: dnitsch Date: Wed, 17 Dec 2025 20:57:44 +0000 Subject: [PATCH 2/6] fix: De Morgans law fix in lint --- cmd/saml.go | 8 ++++++-- cmd/saml_test.go | 29 ++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/cmd/saml.go b/cmd/saml.go index 6d7328b..bafd71b 100755 --- a/cmd/saml.go +++ b/cmd/saml.go @@ -227,11 +227,15 @@ func ConfigFromFlags(fileConfig *credentialexchange.CredentialConfig, rf *RootCm func configValid(config *credentialexchange.CredentialConfig) error { v := validator.New() - + ssoVal := !config.IsSso + if config.IsSso { + ssoVal = len(config.SsoRole) > 0 && len(config.SsoRegion) > 0 + } v.RequiredString(config.ProviderUrl, "provider-url", "provider url must be specified"). RequiredString(config.BaseConfig.Role, "role", "role must be provided"). RequiredString(config.PrincipalArn, "principal-arn", "principal ARN must be provided"). - CustomRule(!(len(config.BaseConfig.Role) > 1 && len(config.SsoRole) > 1), "sso-role", "sso-role cannot be specified when role is also set") + CustomRule(ssoVal, "is-sso", "sso-role must be specified when is-sso is set"). + CustomRule((len(config.BaseConfig.Role) > 1 && len(config.SsoRole) < 1) || (len(config.BaseConfig.Role) < 1 && len(config.SsoRole) > 1), "sso-role", "sso-role cannot be specified when role is also set") if v.IsFailed() { return fmt.Errorf("%w %#q", ErrValidationFailed, v.Errors()) diff --git a/cmd/saml_test.go b/cmd/saml_test.go index 26895d4..d8e8bf8 100644 --- a/cmd/saml_test.go +++ b/cmd/saml_test.go @@ -68,7 +68,7 @@ func Test_ConfigMerge_fails_with_missing(t *testing.T) { t.Error(err) } }) - t.Run("role and sso-role provided", func(t *testing.T) { + t.Run("is-sso set but sso-role not set", func(t *testing.T) { conf := &credentialexchange.CredentialConfig{ BaseConfig: credentialexchange.BaseConfig{ @@ -76,11 +76,30 @@ func Test_ConfigMerge_fails_with_missing(t *testing.T) { Role: "", RoleChain: []string{"role-123"}, }, - SsoRegion: "foo", - SsoRole: "foo:bar", - ProviderUrl: "https://my-idp.com/?app_id=testdd", + PrincipalArn: "some-arn", + SsoRegion: "foo", + SsoRole: "foo:bar", + ProviderUrl: "https://my-idp.com/?app_id=testdd", } - err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{}, "me") + err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "wrong-role"}, "me") + if !errors.Is(err, cmd.ErrValidationFailed) { + t.Error(err) + } + }) + t.Run("role and sso-role both provided", func(t *testing.T) { + + conf := &credentialexchange.CredentialConfig{ + BaseConfig: credentialexchange.BaseConfig{ + BrowserExecutablePath: "/foo/path", + Role: "", + RoleChain: []string{"role-123"}, + }, + PrincipalArn: "some-arn", + SsoRegion: "foo", + SsoRole: "foo:bar", + ProviderUrl: "https://my-idp.com/?app_id=testdd", + } + err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "wrong-role"}, "me") if !errors.Is(err, cmd.ErrValidationFailed) { t.Error(err) } From b49d0c7aa0fb63b5439e862dabb8642c2d4d6741 Mon Sep 17 00:00:00 2001 From: dnitsch Date: Wed, 17 Dec 2025 21:05:13 +0000 Subject: [PATCH 3/6] fix: add more tests to sso edge case --- cmd/saml.go | 4 ++-- cmd/saml_test.go | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/cmd/saml.go b/cmd/saml.go index bafd71b..d6acc34 100755 --- a/cmd/saml.go +++ b/cmd/saml.go @@ -232,10 +232,10 @@ func configValid(config *credentialexchange.CredentialConfig) error { ssoVal = len(config.SsoRole) > 0 && len(config.SsoRegion) > 0 } v.RequiredString(config.ProviderUrl, "provider-url", "provider url must be specified"). - RequiredString(config.BaseConfig.Role, "role", "role must be provided"). + // RequiredString(config.BaseConfig.Role, "role", "role must be provided"). RequiredString(config.PrincipalArn, "principal-arn", "principal ARN must be provided"). CustomRule(ssoVal, "is-sso", "sso-role must be specified when is-sso is set"). - CustomRule((len(config.BaseConfig.Role) > 1 && len(config.SsoRole) < 1) || (len(config.BaseConfig.Role) < 1 && len(config.SsoRole) > 1), "sso-role", "sso-role cannot be specified when role is also set") + CustomRule((len(config.BaseConfig.Role) > 1 && len(config.SsoRole) < 1) || (len(config.BaseConfig.Role) < 1 && len(config.SsoRole) > 1), "role", "sso-role cannot be specified when role is also set") if v.IsFailed() { return fmt.Errorf("%w %#q", ErrValidationFailed, v.Errors()) diff --git a/cmd/saml_test.go b/cmd/saml_test.go index d8e8bf8..275e693 100644 --- a/cmd/saml_test.go +++ b/cmd/saml_test.go @@ -77,11 +77,12 @@ func Test_ConfigMerge_fails_with_missing(t *testing.T) { RoleChain: []string{"role-123"}, }, PrincipalArn: "some-arn", - SsoRegion: "foo", + IsSso: true, + SsoRegion: "", SsoRole: "foo:bar", ProviderUrl: "https://my-idp.com/?app_id=testdd", } - err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{Role: "wrong-role"}, "me") + err := cmd.ConfigFromFlags(conf, &cmd.RootCmdFlags{}, &cmd.SamlCmdFlags{}, "me") if !errors.Is(err, cmd.ErrValidationFailed) { t.Error(err) } From e7c0fc4c6a9bafc0af50a69091d262eec47a80d7 Mon Sep 17 00:00:00 2001 From: dnitsch Date: Wed, 7 Jan 2026 20:46:06 +0000 Subject: [PATCH 4/6] fix: add some logging --- .gitignore | 3 +- aws-cli-auth.go | 14 +- cmd/awscliauth.go | 5 +- cmd/awscliauth_test.go | 3 +- cmd/saml.go | 31 +++-- eirctl.yaml | 14 +- go.mod | 3 + go.sum | 12 ++ internal/credentialexchange/config.go | 46 +++++++ .../credentialexchange/credentialexchange.go | 121 +++++++----------- 10 files changed, 155 insertions(+), 97 deletions(-) diff --git a/.gitignore b/.gitignore index ed3a862..e64afd6 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,5 @@ vendor/ .ignore* local/ .deps/ -.cache/ \ No newline at end of file +.cache/ +*.env diff --git a/aws-cli-auth.go b/aws-cli-auth.go index fa6577e..54c6fc9 100755 --- a/aws-cli-auth.go +++ b/aws-cli-auth.go @@ -2,29 +2,33 @@ package main import ( "context" - "log" "os" "os/signal" "syscall" + "time" "github.com/DevLabFoundry/aws-cli-auth/cmd" + "github.com/rs/zerolog" ) func main() { ctx, stop := signal.NotifyContext(context.Background(), []os.Signal{os.Interrupt, syscall.SIGTERM, os.Kill}...) defer stop() + logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: time.RFC3339}). + Level(zerolog.ErrorLevel). + With().Timestamp(). + Logger() go func() { <-ctx.Done() stop() - // log.Printf("\x1b[31minterrupted: %s\x1b[0m", ctx.Err()) - os.Exit(0) + logger.Fatal().Msgf("\x1b[31minterrupted: %s\x1b[0m", ctx.Err()) }() - c := cmd.New() + c := cmd.New(logger) c.WithSubCommands(cmd.SubCommands()...) if err := c.Execute(ctx); err != nil { - log.Fatalf("\x1b[31m%s\x1b[0m", err) + logger.Fatal().Msgf("\x1b[31m%s\x1b[0m", err) } } diff --git a/cmd/awscliauth.go b/cmd/awscliauth.go index 39f4fcc..0e46063 100755 --- a/cmd/awscliauth.go +++ b/cmd/awscliauth.go @@ -9,6 +9,7 @@ import ( "github.com/DevLabFoundry/aws-cli-auth/internal/credentialexchange" "github.com/Ensono/eirctl/selfupdate" + "github.com/rs/zerolog" "github.com/spf13/cobra" ) @@ -23,6 +24,7 @@ type Root struct { // ChannelErr io.Writer // viperConf *viper.Viper rootFlags *RootCmdFlags + logger zerolog.Logger Datadir string } @@ -35,9 +37,10 @@ type RootCmdFlags struct { CustomIniLocation string } -func New() *Root { +func New(logger zerolog.Logger) *Root { rf := &RootCmdFlags{} r := &Root{ + logger: logger, rootFlags: rf, Cmd: &cobra.Command{ Use: "aws-cli-auth", diff --git a/cmd/awscliauth_test.go b/cmd/awscliauth_test.go index 14419dd..6145fd9 100644 --- a/cmd/awscliauth_test.go +++ b/cmd/awscliauth_test.go @@ -12,13 +12,14 @@ import ( "github.com/DevLabFoundry/aws-cli-auth/cmd" "github.com/DevLabFoundry/aws-cli-auth/internal/credentialexchange" "github.com/DevLabFoundry/aws-cli-auth/internal/web" + "github.com/rs/zerolog" ) func cmdHelperExecutor(t *testing.T, args []string) (stdOut *bytes.Buffer, errOut *bytes.Buffer, err error) { t.Helper() errOut = new(bytes.Buffer) stdOut = new(bytes.Buffer) - c := cmd.New() + c := cmd.New(zerolog.New(io.Discard)) c.WithSubCommands(cmd.SubCommands()...) c.Cmd.SetArgs(args) c.Cmd.SetErr(errOut) diff --git a/cmd/saml.go b/cmd/saml.go index d6acc34..f20bdf1 100755 --- a/cmd/saml.go +++ b/cmd/saml.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sts" validator "github.com/rezakhademix/govalidator/v2" + "github.com/rs/zerolog" "github.com/spf13/cobra" "gopkg.in/ini.v1" ) @@ -66,17 +67,24 @@ func newSamlCmd(r *Root) { if err != nil { return err } - + if r.rootFlags.Verbose { + r.logger = r.logger.Level(zerolog.DebugLevel) + } + r.logger.Debug().Str("CustomIniLocation", r.rootFlags.CustomIniLocation).Msg("if empty using default ~/.aws-cli-auth.ini") iniFile, err := samlInitConfig(r.rootFlags.CustomIniLocation) if err != nil { return err } + r.logger.Debug().Msgf("iniFile: %+v", iniFile) + conf, err := credentialexchange.LoadCliConfig(iniFile, r.rootFlags.CfgSectionName) if err != nil { return err } + r.logger.Debug().Str("section", r.rootFlags.CfgSectionName).Msgf("loaded section: %+v", conf) + if err := ConfigFromFlags(conf, r.rootFlags, flags, user.Username); err != nil { return err } @@ -97,6 +105,11 @@ func newSamlCmd(r *Root) { saveRole = allRoles[len(allRoles)-1] } + r.logger.Debug().Str("saveRole", saveRole). + Str("SsoEndpoint", conf.SsoUserEndpoint). + Str("SsoCredFedEndpoint", conf.SsoCredFedEndpoint). + Msg("") + secretStore, err := credentialexchange.NewSecretStore(saveRole, fmt.Sprintf("%s-%s", credentialexchange.SELF_NAME, credentialexchange.RoleKeyConverter(saveRole)), os.TempDir(), user.Username) @@ -104,18 +117,16 @@ func newSamlCmd(r *Root) { return err } - // we want to remove any AWS_* env vars that could interfere with the default config - // for _, envVar := range []string{"AWS_PROFILE", "AWS_ACCESS_KEY_ID", - // "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"} { - // os.Unsetenv(envVar) - // } - - awsConf, err := config.LoadDefaultConfig(ctx) + cfg, err := config.LoadDefaultConfig(ctx) if err != nil { return fmt.Errorf("failed to create session %s, %w", err, ErrUnableToCreateSession) } - svc := sts.NewFromConfig(awsConf) + if cfg.Region == "" { + return fmt.Errorf("unable to deduce AWS region, AWS_REGION, AWS_DEFAULT_REGION, ~/.aws/config default or profile level region must be set") + } + + svc := sts.NewFromConfig(cfg) webConfig := web.NewWebConf(r.Datadir). WithTimeout(flags.SamlTimeout). WithCustomExecutable(conf.BaseConfig.BrowserExecutablePath) @@ -167,7 +178,7 @@ If this flag is specified the --sso-role must also be specified.`) // sc.cmd.MarkFlagsRequiredTogether("principal", "role") // SSO flow for SAML sc.cmd.MarkFlagsRequiredTogether("is-sso", "sso-role", "sso-region") - sc.cmd.PersistentFlags().Int32VarP(&flags.SamlTimeout, "saml-timeout", "", 120, "Timeout in seconds, before the operation of waiting for a response is cancelled via the chrome driver") + sc.cmd.PersistentFlags().Int32VarP(&flags.SamlTimeout, "saml-timeout", "", 120, "Timeout in seconds, before the operation of waiting for a response is cancelled via CDP (ChromeDeubgProto)") // Add subcommand to root command r.Cmd.AddCommand(sc.cmd) } diff --git a/eirctl.yaml b/eirctl.yaml index 678352e..a08a26e 100644 --- a/eirctl.yaml +++ b/eirctl.yaml @@ -1,5 +1,5 @@ import: - - https://raw.githubusercontent.com/Ensono/eirctl/e71dd9d66293e27e70fd0620e63a6d627579c060/shared/build/go/eirctl.yaml + - https://raw.githubusercontent.com/Ensono/eirctl/refs/tags/v0.9.7/shared/build/go/eirctl.yaml contexts: unit:test: @@ -12,7 +12,7 @@ contexts: - GO pipelines: - build: + build: - task: build:unix - task: build:win depends_on: build:unix @@ -20,6 +20,8 @@ pipelines: unit:test:run: - task: unit:test:prereqs - task: unit:test + env: + ROOT_PKG_NAME: github.com/DevLabFoundry depends_on: unit:test:prereqs bin:release: @@ -34,7 +36,7 @@ pipelines: tasks: tag: - command: + command: - | git tag -a ${VERSION} -m "ci tag release" ${REVISION} git push origin ${VERSION} @@ -48,6 +50,7 @@ tasks: description: | Unit test runner needs a bit of extra care in this case to ensure we have all the dependencies command: | + unset GOTOOLCHAIN export GOPATH=$PWD/.deps GOBIN=$PWD/.deps/bin CGO_ENABLED=1 go test ./... -v -coverpkg=github.com/DevLabFoundry/... -race -mod=readonly -timeout=1m -shuffle=on -buildvcs=false -coverprofile=.coverage/out -count=1 -run=$GO_TEST_RUN_ARGS | tee .coverage/test.out cat .coverage/test.out | .deps/bin/go-junit-report > .coverage/report-junit.xml @@ -56,7 +59,7 @@ tasks: unit:test:prereqs: description: Installs coverage and junit tools context: unit:test - command: + command: - | mkdir -p .coverage export GOPATH="${PWD}/.deps" GOBIN="${PWD}/.deps/bin" @@ -65,13 +68,14 @@ tasks: go install github.com/AlekSi/gocov-xml@v1.0.0 clean:dir: - command: + command: - | rm -rf dist/ build:win: context: go1x description: Builds Go binary + reset_context: true command: - | mkdir -p .deps diff --git a/go.mod b/go.mod index 104f769..9cd8fd2 100644 --- a/go.mod +++ b/go.mod @@ -16,8 +16,11 @@ require ( ) require ( + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/rs/zerolog v1.34.0 // indirect github.com/schollz/progressbar/v3 v3.18.0 // indirect golang.org/x/term v0.37.0 // indirect ) diff --git a/go.sum b/go.sum index 4dd7409..03fe12e 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,7 @@ github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7m github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY= github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdohwgs8tY= github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/danieljoos/wincred v1.2.3 h1:v7dZC2x32Ut3nEfRH+vhoZGvN72+dQ/snVXo/vMFLdQ= github.com/danieljoos/wincred v1.2.3/go.mod h1:6qqX0WNrS4RzPZ1tnroDzq9kY3fu1KwE7MRLQK4X0bs= @@ -51,6 +52,7 @@ github.com/go-rod/rod v0.116.2 h1:A5t2Ky2A+5eD/ZJQr1EfsQSe5rms5Xof/qj296e+ZqA= github.com/go-rod/rod v0.116.2/go.mod h1:H+CMO9SCNc2TJ2WfrG+pKhITz57uGNYU43qYHh438Mg= github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= @@ -65,20 +67,27 @@ github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcI github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rezakhademix/govalidator/v2 v2.1.2 h1:qqCIkWC6sWr8zeW9zCkYEJxbZMt/Dn1ASXkGIQe3rDI= github.com/rezakhademix/govalidator/v2 v2.1.2/go.mod h1:be7JrYM3STiL5jYt1WrQN5ArR8xTov/DvWJ9yXtULj8= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/schollz/progressbar/v3 v3.18.0 h1:uXdoHABRFmNIjUfte/Ex7WtuyVslrw2wVPQmCN62HpA= github.com/schollz/progressbar/v3 v3.18.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= @@ -117,6 +126,9 @@ github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8u github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= diff --git a/internal/credentialexchange/config.go b/internal/credentialexchange/config.go index 430a2f4..93bcb55 100644 --- a/internal/credentialexchange/config.go +++ b/internal/credentialexchange/config.go @@ -1,5 +1,11 @@ package credentialexchange +import ( + "encoding/json" + "fmt" + "time" +) + const ( SELF_NAME = "aws-cli-auth" WEB_ID_TOKEN_VAR = "AWS_WEB_IDENTITY_TOKEN_FILE" @@ -29,3 +35,43 @@ type CredentialConfig struct { SsoUserEndpoint string `ini:"is-sso-endpoint"` SsoCredFedEndpoint string } + +// AWSRole aws role attributes +type AWSRoleConfig struct { + RoleARN string + PrincipalARN string + Name string +} + +// AWSCredentials is a representation of the returned credential +type AWSCredentials struct { + Version int + AWSAccessKey string `json:"AccessKeyId"` + AWSSecretKey string `json:"SecretAccessKey"` + AWSSessionToken string `json:"SessionToken"` + PrincipalARN string `json:"-"` + Expires time.Time `json:"Expiration"` +} + +// roleCreds can be encapsulated in this function +// never used outside of this scope for now +type roleCreds struct { + RoleCreds struct { + AccessKey string `json:"accessKeyId"` + SecretKey string `json:"secretAccessKey"` + SessionToken string `json:"sessionToken"` + Expiration int64 `json:"expiration"` + } `json:"roleCredentials"` +} + +func (a *AWSCredentials) FromRoleCredString(cred string) (*AWSCredentials, error) { + rc := &roleCreds{} + if err := json.Unmarshal([]byte(cred), rc); err != nil { + return nil, fmt.Errorf("%s, %w", err, ErrUnmarshalCred) + } + a.AWSAccessKey = rc.RoleCreds.AccessKey + a.AWSSecretKey = rc.RoleCreds.SecretKey + a.AWSSessionToken = rc.RoleCreds.SessionToken + a.Expires = time.UnixMilli(rc.RoleCreds.Expiration) + return a, nil +} diff --git a/internal/credentialexchange/credentialexchange.go b/internal/credentialexchange/credentialexchange.go index d9dc7f0..5588d56 100755 --- a/internal/credentialexchange/credentialexchange.go +++ b/internal/credentialexchange/credentialexchange.go @@ -2,7 +2,6 @@ package credentialexchange import ( "context" - "encoding/json" "errors" "fmt" "os" @@ -21,51 +20,30 @@ var ( ErrUnmarshalCred = errors.New("unable to unmarshal credential from string") ) -// AWSRole aws role attributes -type AWSRoleConfig struct { - RoleARN string - PrincipalARN string - Name string -} - -// AWSCredentials is a representation of the returned credential -type AWSCredentials struct { - Version int - AWSAccessKey string `json:"AccessKeyId"` - AWSSecretKey string `json:"SecretAccessKey"` - AWSSessionToken string `json:"SessionToken"` - PrincipalARN string `json:"-"` - Expires time.Time `json:"Expiration"` -} - -func (a *AWSCredentials) FromRoleCredString(cred string) (*AWSCredentials, error) { - // RoleCreds can be encapsulated in this function - // never used outside of this scope for now - type RoleCreds struct { - RoleCreds struct { - AccessKey string `json:"accessKeyId"` - SecretKey string `json:"secretAccessKey"` - SessionToken string `json:"sessionToken"` - Expiration int64 `json:"expiration"` - } `json:"roleCredentials"` - } - rc := &RoleCreds{} - if err := json.Unmarshal([]byte(cred), rc); err != nil { - return nil, fmt.Errorf("%s, %w", err, ErrUnmarshalCred) - } - a.AWSAccessKey = rc.RoleCreds.AccessKey - a.AWSSecretKey = rc.RoleCreds.SecretKey - a.AWSSessionToken = rc.RoleCreds.SessionToken - a.Expires = time.UnixMilli(rc.RoleCreds.Expiration) - return a, nil -} - type AuthSamlApi interface { AssumeRoleWithSAML(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) } +type authWebTokenApi interface { + AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) +} + +// type CredentialExchange struct { +// logger zerolog.Logger +// samlSvc AuthSamlApi +// specificSvc authWebTokenApi +// } + +// func New(logger zerolog.Logger, samlSvc AuthSamlApi, specificSvc authWebTokenApi) *CredentialExchange { +// return &CredentialExchange{ +// logger: logger, +// samlSvc: samlSvc, +// specificSvc: specificSvc, +// } +// } + // LoginStsSaml exchanges saml response for STS creds func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc AuthSamlApi) (*AWSCredentials, error) { @@ -76,9 +54,13 @@ func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc Au DurationSeconds: aws.Int32(int32(role.Duration)), } + // unsetting the AWS_PROFILE here as we want to assume using samlResp credentials + // + // if profile is set the credential provider fails to cascade back to `[default]` section in ~/.aws/config + // os.Unsetenv("AWS_PROFILE") resp, err := svc.AssumeRoleWithSAML(ctx, params) if err != nil { - return nil, fmt.Errorf("failed to retrieve STS credentials using SAML: %s, %w", err.Error(), ErrUnableAssume) + return nil, fmt.Errorf("%w, failed to retrieve STS credentials using SAML: %s", ErrUnableAssume, err.Error()) } return &AWSCredentials{ @@ -90,15 +72,6 @@ func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc Au }, nil } -type credsProvider struct { - accessKey, secretKey, sessionToken string - expiry time.Time -} - -func (c *credsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { - return aws.Credentials{AccessKeyID: c.accessKey, SecretAccessKey: c.secretKey, SessionToken: c.sessionToken, CanExpire: true, Expires: c.expiry}, nil -} - // IsValid checks current credentials and // returns them if they are still valid // if reloadTimeBefore is less than time left on the creds @@ -109,11 +82,6 @@ func IsValid(ctx context.Context, currentCreds *AWSCredentials, reloadBeforeTime } if _, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(o *sts.Options) { - // set the default region for the - // if o.EndpointOptions.GetResolvedRegion() == "" { - // // cannot determine - // o.BaseEndpoint = aws.String("https://sts.amazonaws.com") - // } o.Credentials = &credsProvider{currentCreds.AWSAccessKey, currentCreds.AWSSecretKey, currentCreds.AWSSessionToken, currentCreds.Expires} }); err != nil { // var oe *smithy.OperationError @@ -130,10 +98,6 @@ func IsValid(ctx context.Context, currentCreds *AWSCredentials, reloadBeforeTime return !ReloadBeforeExpiry(currentCreds.Expires, reloadBeforeTime), nil } -type authWebTokenApi interface { - AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) -} - // LoginAwsWebToken func LoginAwsWebToken(ctx context.Context, username string, svc authWebTokenApi) (*AWSCredentials, error) { // var role string @@ -167,6 +131,23 @@ func LoginAwsWebToken(ctx context.Context, username string, svc authWebTokenApi) }, nil } +// AssumeRoleInChain loops over all the roles provided +// If none are provided it will return the baseCreds +func AssumeRoleInChain(ctx context.Context, baseCreds *AWSCredentials, svc AuthSamlApi, username string, roles []string, conf CredentialConfig) (*AWSCredentials, error) { + duration := int32(900) + for idx, r := range roles { + if len(roles) == idx+1 { + duration = int32(conf.Duration) + } + c, err := assumeRoleWithCreds(ctx, baseCreds, svc, username, r, duration) + if err != nil { + return nil, err + } + baseCreds = c + } + return baseCreds, nil +} + // AssumeRoleWithCreds uses existing creds retrieved from anywhere // to pass to a credential provider and assume a specific role // @@ -199,19 +180,11 @@ func assumeRoleWithCreds(ctx context.Context, currentCreds *AWSCredentials, svc }, nil } -// AssumeRoleInChain loops over all the roles provided -// If none are provided it will return the baseCreds -func AssumeRoleInChain(ctx context.Context, baseCreds *AWSCredentials, svc AuthSamlApi, username string, roles []string, conf CredentialConfig) (*AWSCredentials, error) { - duration := int32(900) - for idx, r := range roles { - if len(roles) == idx+1 { - duration = int32(conf.Duration) - } - c, err := assumeRoleWithCreds(ctx, baseCreds, svc, username, r, duration) - if err != nil { - return nil, err - } - baseCreds = c - } - return baseCreds, nil +type credsProvider struct { + accessKey, secretKey, sessionToken string + expiry time.Time +} + +func (c *credsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { + return aws.Credentials{AccessKeyID: c.accessKey, SecretAccessKey: c.secretKey, SessionToken: c.sessionToken, CanExpire: true, Expires: c.expiry}, nil } From bc161cd49a2dbd7069b5afdc008de6c03d4abe3a Mon Sep 17 00:00:00 2001 From: dnitsch Date: Wed, 7 Jan 2026 20:55:20 +0000 Subject: [PATCH 5/6] fix: remove comment REVERT: when needed we can bring back struct for credentialexchange --- .../credentialexchange/credentialexchange.go | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/internal/credentialexchange/credentialexchange.go b/internal/credentialexchange/credentialexchange.go index 5588d56..03ce845 100755 --- a/internal/credentialexchange/credentialexchange.go +++ b/internal/credentialexchange/credentialexchange.go @@ -30,20 +30,6 @@ type authWebTokenApi interface { AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } -// type CredentialExchange struct { -// logger zerolog.Logger -// samlSvc AuthSamlApi -// specificSvc authWebTokenApi -// } - -// func New(logger zerolog.Logger, samlSvc AuthSamlApi, specificSvc authWebTokenApi) *CredentialExchange { -// return &CredentialExchange{ -// logger: logger, -// samlSvc: samlSvc, -// specificSvc: specificSvc, -// } -// } - // LoginStsSaml exchanges saml response for STS creds func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc AuthSamlApi) (*AWSCredentials, error) { @@ -54,10 +40,6 @@ func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc Au DurationSeconds: aws.Int32(int32(role.Duration)), } - // unsetting the AWS_PROFILE here as we want to assume using samlResp credentials - // - // if profile is set the credential provider fails to cascade back to `[default]` section in ~/.aws/config - // os.Unsetenv("AWS_PROFILE") resp, err := svc.AssumeRoleWithSAML(ctx, params) if err != nil { return nil, fmt.Errorf("%w, failed to retrieve STS credentials using SAML: %s", ErrUnableAssume, err.Error()) From 93c20c4cfb881d0f3969b6e01bb14c0ef3eac3a1 Mon Sep 17 00:00:00 2001 From: dnitsch Date: Sat, 31 Jan 2026 14:39:32 +0000 Subject: [PATCH 6/6] fix: layout of tests in cmdutils --- cmd/awscliauth.go | 5 +- cmd/saml.go | 4 +- cmd/specific.go | 7 +- go.mod | 3 +- go.sum | 2 + internal/cmdutils/cmdutils.go | 28 +-- internal/cmdutils/cmdutils_test.go | 174 +++++------------- .../credentialexchange/credentialexchange.go | 47 ++++- .../credentialexchange_test.go | 121 ++++++------ 9 files changed, 172 insertions(+), 219 deletions(-) diff --git a/cmd/awscliauth.go b/cmd/awscliauth.go index 0e46063..62849f0 100755 --- a/cmd/awscliauth.go +++ b/cmd/awscliauth.go @@ -10,6 +10,7 @@ import ( "github.com/DevLabFoundry/aws-cli-auth/internal/credentialexchange" "github.com/Ensono/eirctl/selfupdate" "github.com/rs/zerolog" + "github.com/savioxavier/termlink" "github.com/spf13/cobra" ) @@ -60,8 +61,8 @@ Stores them under the $HOME/.aws/credentials file under a specified path or retu r.Cmd.PersistentFlags().StringVarP(&rf.CfgSectionName, "cfg-section", "", "", "Config section name to use in the look up of the config ini file (~/.aws-cli-auth.ini) and in the AWS credentials file") // When specifying store in profile the config section name must be provided r.Cmd.MarkFlagsRequiredTogether("store-profile", "cfg-section") - r.Cmd.PersistentFlags().IntVarP(&rf.Duration, "max-duration", "d", 900, `Override default max session duration, in seconds, of the role session [900-43200]. -NB: This cannot be higher than the 3600 as the API does not allow for AssumeRole for sessions longer than an hour`) + r.Cmd.PersistentFlags().IntVarP(&rf.Duration, "max-duration", "d", 900, fmt.Sprintf("Override default max session duration, in seconds, of the role session [900-43200].\nNB: This cannot be higher than the 3600 as the API does not allow for AssumeRole for sessions longer than an hour\nMore info on this and especially around role-chaining\nSee %s", + termlink.Link("AWS SDK Duration", "https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_RequestParameters"))) r.Cmd.PersistentFlags().BoolVarP(&rf.Verbose, "verbose", "v", false, "Verbose output") r.Cmd.PersistentFlags().StringVarP(&rf.CustomIniLocation, "config-file", "c", "", "Specify the custom location of config file") diff --git a/cmd/saml.go b/cmd/saml.go index f20bdf1..ce4c48f 100755 --- a/cmd/saml.go +++ b/cmd/saml.go @@ -123,15 +123,17 @@ func newSamlCmd(r *Root) { } if cfg.Region == "" { + // cfg.en return fmt.Errorf("unable to deduce AWS region, AWS_REGION, AWS_DEFAULT_REGION, ~/.aws/config default or profile level region must be set") } svc := sts.NewFromConfig(cfg) + cre := credentialexchange.New(r.logger, svc) webConfig := web.NewWebConf(r.Datadir). WithTimeout(flags.SamlTimeout). WithCustomExecutable(conf.BaseConfig.BrowserExecutablePath) - return cmdutils.GetCredsWebUI(ctx, svc, secretStore, *conf, webConfig) + return cmdutils.GetCredsWebUI(ctx, cre, secretStore, *conf, webConfig) }, PreRunE: func(cmd *cobra.Command, args []string) error { diff --git a/cmd/specific.go b/cmd/specific.go index 1665418..be7f1bc 100644 --- a/cmd/specific.go +++ b/cmd/specific.go @@ -39,15 +39,16 @@ Returns the same JSON object as the call to the AWS CLI for any of the sts Assum svc := sts.NewFromConfig(cfg) user, err := user.Current() - if err != nil { return err } + cre := credentialexchange.New(r.logger, svc) + if flags.method != "" { switch flags.method { case "WEB_ID": - awsCreds, err = credentialexchange.LoginAwsWebToken(ctx, user.Name, svc) + awsCreds, err = cre.LoginAwsWebToken(ctx, user.Name) if err != nil { return err } @@ -69,7 +70,7 @@ Returns the same JSON object as the call to the AWS CLI for any of the sts Assum Duration: r.rootFlags.Duration, } - awsCreds, err = credentialexchange.AssumeRoleInChain(ctx, awsCreds, svc, config.BaseConfig.Username, config.BaseConfig.RoleChain, conf) + awsCreds, err = cre.AssumeRoleInChain(ctx, awsCreds, config.BaseConfig.Username, config.BaseConfig.RoleChain, conf) if err != nil { return err } diff --git a/go.mod b/go.mod index 9cd8fd2..c96b6d5 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,8 @@ require ( github.com/aws/smithy-go v1.23.2 github.com/go-rod/rod v0.116.2 github.com/rezakhademix/govalidator/v2 v2.1.2 + github.com/rs/zerolog v1.34.0 + github.com/savioxavier/termlink v1.4.3 github.com/spf13/cobra v1.10.1 github.com/werf/lockgate v0.1.1 github.com/zalando/go-keyring v0.2.6 @@ -20,7 +22,6 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/rs/zerolog v1.34.0 // indirect github.com/schollz/progressbar/v3 v3.18.0 // indirect golang.org/x/term v0.37.0 // indirect ) diff --git a/go.sum b/go.sum index 03fe12e..fc7250c 100644 --- a/go.sum +++ b/go.sum @@ -89,6 +89,8 @@ github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/savioxavier/termlink v1.4.3 h1:Gh6vrG7jSn21cRiYdQqFXYcdXfM+Fg14aG487JTfKpA= +github.com/savioxavier/termlink v1.4.3/go.mod h1:5T5ePUlWbxCHIwyF8/Ez1qufOoGM89RCg9NvG+3G3gc= github.com/schollz/progressbar/v3 v3.18.0 h1:uXdoHABRFmNIjUfte/Ex7WtuyVslrw2wVPQmCN62HpA= github.com/schollz/progressbar/v3 v3.18.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= diff --git a/internal/cmdutils/cmdutils.go b/internal/cmdutils/cmdutils.go index af005b6..19ef66e 100644 --- a/internal/cmdutils/cmdutils.go +++ b/internal/cmdutils/cmdutils.go @@ -21,8 +21,14 @@ type SecretStorageImpl interface { SaveAWSCredential(cred *credentialexchange.AWSCredentials) error } +type CredentialExchangeImpl interface { + IsValid(ctx context.Context, currentCreds *credentialexchange.AWSCredentials, reloadBeforeTime int) (bool, error) + LoginStsSaml(ctx context.Context, samlResponse string, role credentialexchange.AWSRole) (*credentialexchange.AWSCredentials, error) + AssumeRoleInChain(ctx context.Context, baseCreds *credentialexchange.AWSCredentials, username string, roles []string, conf credentialexchange.CredentialConfig) (*credentialexchange.AWSCredentials, error) +} + // GetCredsWebUI -func GetCredsWebUI(ctx context.Context, svc credentialexchange.AuthSamlApi, secretStore SecretStorageImpl, conf credentialexchange.CredentialConfig, webConfig *web.WebConfig) error { +func GetCredsWebUI(ctx context.Context, creSvc CredentialExchangeImpl, secretStore SecretStorageImpl, conf credentialexchange.CredentialConfig, webConfig *web.WebConfig) error { if conf.BaseConfig.CfgSectionName == "" && conf.BaseConfig.StoreInProfile { return fmt.Errorf("Config-Section name must be provided if store-profile is enabled %w", ErrMissingArg) } @@ -33,7 +39,7 @@ func GetCredsWebUI(ctx context.Context, svc credentialexchange.AuthSamlApi, secr return err } - credsValid, err := credentialexchange.IsValid(ctx, storedCreds, conf.BaseConfig.ReloadBeforeTime, svc) + credsValid, err := creSvc.IsValid(ctx, storedCreds, conf.BaseConfig.ReloadBeforeTime) if err != nil { return fmt.Errorf("failed to validate: %s, %w", err, ErrUnableToValidate) } @@ -41,9 +47,9 @@ func GetCredsWebUI(ctx context.Context, svc credentialexchange.AuthSamlApi, secr if !credsValid { // TODO: delete from keychain first if conf.IsSso { - return refreshAwsSsoCreds(ctx, conf, secretStore, svc, webConfig) + return refreshAwsSsoCreds(ctx, conf, secretStore, creSvc, webConfig) } - return refreshSamlCreds(ctx, conf, secretStore, svc, webConfig) + return refreshSamlCreds(ctx, conf, secretStore, creSvc, webConfig) } return credentialexchange.SetCredentials(storedCreds, conf) @@ -52,7 +58,7 @@ func GetCredsWebUI(ctx context.Context, svc credentialexchange.AuthSamlApi, secr // refreshAwsSsoCreds uses the temp user credentials returned via AWS SSO, // upon successful auth from the IDP. // Once credentials are captured they are used in the role assumption process. -func refreshAwsSsoCreds(ctx context.Context, conf credentialexchange.CredentialConfig, secretStore SecretStorageImpl, svc credentialexchange.AuthSamlApi, webConfig *web.WebConfig) error { +func refreshAwsSsoCreds(ctx context.Context, conf credentialexchange.CredentialConfig, secretStore SecretStorageImpl, creSvc CredentialExchangeImpl, webConfig *web.WebConfig) error { webBrowser, err := web.New(ctx, webConfig) if err != nil { return err @@ -63,10 +69,10 @@ func refreshAwsSsoCreds(ctx context.Context, conf credentialexchange.CredentialC } awsCreds := &credentialexchange.AWSCredentials{} _, _ = awsCreds.FromRoleCredString(capturedCreds) - return completeCredProcess(ctx, secretStore, svc, awsCreds, conf) + return completeCredProcess(ctx, secretStore, creSvc, awsCreds, conf) } -func refreshSamlCreds(ctx context.Context, conf credentialexchange.CredentialConfig, secretStore SecretStorageImpl, svc credentialexchange.AuthSamlApi, webConfig *web.WebConfig) error { +func refreshSamlCreds(ctx context.Context, conf credentialexchange.CredentialConfig, secretStore SecretStorageImpl, creSvc CredentialExchangeImpl, webConfig *web.WebConfig) error { webBrowser, err := web.New(ctx, webConfig) if err != nil { @@ -95,15 +101,15 @@ func refreshSamlCreds(ctx context.Context, conf credentialexchange.CredentialCon Duration: duration, } - awsCreds, err := credentialexchange.LoginStsSaml(ctx, samlResp, roleObj, svc) + awsCreds, err := creSvc.LoginStsSaml(ctx, samlResp, roleObj) if err != nil { return err } - return completeCredProcess(ctx, secretStore, svc, awsCreds, conf) + return completeCredProcess(ctx, secretStore, creSvc, awsCreds, conf) } -func completeCredProcess(ctx context.Context, secretStore SecretStorageImpl, svc credentialexchange.AuthSamlApi, awsCreds *credentialexchange.AWSCredentials, conf credentialexchange.CredentialConfig) error { - creds, err := credentialexchange.AssumeRoleInChain(ctx, awsCreds, svc, conf.BaseConfig.Username, conf.BaseConfig.RoleChain, conf) +func completeCredProcess(ctx context.Context, secretStore SecretStorageImpl, creSvc CredentialExchangeImpl, awsCreds *credentialexchange.AWSCredentials, conf credentialexchange.CredentialConfig) error { + creds, err := creSvc.AssumeRoleInChain(ctx, awsCreds, conf.BaseConfig.Username, conf.BaseConfig.RoleChain, conf) if err != nil { return err } diff --git a/internal/cmdutils/cmdutils_test.go b/internal/cmdutils/cmdutils_test.go index 5dd77e2..ddbd10c 100644 --- a/internal/cmdutils/cmdutils_test.go +++ b/internal/cmdutils/cmdutils_test.go @@ -13,9 +13,6 @@ import ( "github.com/DevLabFoundry/aws-cli-auth/internal/cmdutils" "github.com/DevLabFoundry/aws-cli-auth/internal/credentialexchange" "github.com/DevLabFoundry/aws-cli-auth/internal/web" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/sts" - "github.com/aws/aws-sdk-go-v2/service/sts/types" "gopkg.in/ini.v1" ) @@ -145,22 +142,31 @@ func testConfig() credentialexchange.CredentialConfig { } } -type mockAuthApi struct { - assumeRoleWSaml func(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) - getCallId func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) - assume func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) +type mockCredExchangeApi struct { + isValid func(ctx context.Context, currentCreds *credentialexchange.AWSCredentials, reloadBeforeTime int) (bool, error) + loginStsSaml func(ctx context.Context, samlResponse string, role credentialexchange.AWSRole) (*credentialexchange.AWSCredentials, error) + assumeRoleInChain func(ctx context.Context, baseCreds *credentialexchange.AWSCredentials, username string, roles []string, conf credentialexchange.CredentialConfig) (*credentialexchange.AWSCredentials, error) } -func (m *mockAuthApi) AssumeRoleWithSAML(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { - return m.assumeRoleWSaml(ctx, params, optFns...) +func (m *mockCredExchangeApi) IsValid(ctx context.Context, currentCreds *credentialexchange.AWSCredentials, reloadBeforeTime int) (bool, error) { + if m.isValid != nil { + return m.isValid(ctx, currentCreds, reloadBeforeTime) + } + return false, nil } -func (m *mockAuthApi) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { - return m.getCallId(ctx, params, optFns...) +func (m *mockCredExchangeApi) LoginStsSaml(ctx context.Context, samlResponse string, role credentialexchange.AWSRole) (*credentialexchange.AWSCredentials, error) { + if m.loginStsSaml != nil { + return m.loginStsSaml(ctx, samlResponse, role) + } + return &credentialexchange.AWSCredentials{}, nil } -func (m *mockAuthApi) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { - return m.assume(ctx, params, optFns...) +func (m *mockCredExchangeApi) AssumeRoleInChain(ctx context.Context, baseCreds *credentialexchange.AWSCredentials, username string, roles []string, conf credentialexchange.CredentialConfig) (*credentialexchange.AWSCredentials, error) { + if m.assumeRoleInChain != nil { + return m.assumeRoleInChain(ctx, baseCreds, username, roles, conf) + } + return &credentialexchange.AWSCredentials{}, nil } type mockSecretApi struct { @@ -188,64 +194,20 @@ func (s *mockSecretApi) SaveAWSCredential(cred *credentialexchange.AWSCredential func Test_GetSamlCreds_With(t *testing.T) { ttests := map[string]struct { - config func(t *testing.T) credentialexchange.CredentialConfig - handler func(t *testing.T, awsMock bool) http.Handler - authApi func(t *testing.T) credentialexchange.AuthSamlApi - secretStore func(t *testing.T) cmdutils.SecretStorageImpl - expectErr bool - errTyp error + config func(t *testing.T) credentialexchange.CredentialConfig + handler func(t *testing.T, awsMock bool) http.Handler + credExchange func(t *testing.T) cmdutils.CredentialExchangeImpl + secretStore func(t *testing.T) cmdutils.SecretStorageImpl + expectErr bool + errTyp error }{ "correct config and extracted creds but not valid anymore": { config: func(t *testing.T) credentialexchange.CredentialConfig { return testConfig() }, handler: IdpHandler, - authApi: func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} - m.assumeRoleWSaml = func(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { - return &sts.AssumeRoleWithSAMLOutput{ - AssumedRoleUser: &types.AssumedRoleUser{ - AssumedRoleId: aws.String("some-role"), - Arn: aws.String("arn"), - }, - Audience: new(string), - Credentials: &types.Credentials{ - AccessKeyId: aws.String("123213"), - SecretAccessKey: aws.String("32798hewf"), - SessionToken: aws.String("49hefusdSOM_LONG_TOKEN_HERE"), - Expiration: aws.Time(time.Now().Local().Add(time.Minute * time.Duration(5))), - }, - Issuer: new(string), - NameQualifier: new(string), - PackedPolicySize: new(int32), - SourceIdentity: new(string), - Subject: new(string), - SubjectType: new(string), - }, nil - } - - m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { - // t.Error() - return &sts.GetCallerIdentityOutput{ - Account: aws.String("1122223334"), - Arn: aws.String("arn:aws:iam::1122223334:role/some-role"), - UserId: aws.String("some-user-id"), - }, nil - } - m.assume = func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { - return &sts.AssumeRoleOutput{ - AssumedRoleUser: &types.AssumedRoleUser{ - AssumedRoleId: aws.String("some-role"), - Arn: aws.String("arn"), - }, - Credentials: &types.Credentials{ - AccessKeyId: aws.String("123213"), - SecretAccessKey: aws.String("32798hewf"), - SessionToken: aws.String("49hefusdSOM_LONG_TOKEN_HERE"), - Expiration: aws.Time(time.Now().Local().Add(time.Minute * time.Duration(5))), - }, - }, nil - } + credExchange: func(t *testing.T) cmdutils.CredentialExchangeImpl { + m := &mockCredExchangeApi{} return m }, secretStore: func(t *testing.T) cmdutils.SecretStorageImpl { @@ -274,37 +236,10 @@ func Test_GetSamlCreds_With(t *testing.T) { return conf }, handler: IdpHandler, - authApi: func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} - m.assumeRoleWSaml = func(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { - return &sts.AssumeRoleWithSAMLOutput{ - AssumedRoleUser: &types.AssumedRoleUser{ - AssumedRoleId: aws.String("some-role"), - Arn: aws.String("arn"), - }, - Audience: new(string), - Credentials: &types.Credentials{ - AccessKeyId: aws.String("123213"), - SecretAccessKey: aws.String("32798hewf"), - SessionToken: aws.String("49hefusdSOM_LONG_TOKEN_HERE"), - Expiration: aws.Time(time.Now().Local().Add(time.Minute * time.Duration(5))), - }, - Issuer: new(string), - NameQualifier: new(string), - PackedPolicySize: new(int32), - SourceIdentity: new(string), - Subject: new(string), - SubjectType: new(string), - }, nil - } - - m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { - // t.Error() - return &sts.GetCallerIdentityOutput{ - Account: aws.String("1122223334"), - Arn: aws.String("arn:aws:iam::1122223334:role/some-role"), - UserId: aws.String("some-user-id"), - }, nil + credExchange: func(t *testing.T) cmdutils.CredentialExchangeImpl { + m := &mockCredExchangeApi{} + m.isValid = func(ctx context.Context, currentCreds *credentialexchange.AWSCredentials, reloadBeforeTime int) (bool, error) { + return true, nil } return m @@ -336,8 +271,8 @@ func Test_GetSamlCreds_With(t *testing.T) { return tc }, handler: IdpHandler, - authApi: func(t *testing.T) credentialexchange.AuthSamlApi { - return &mockAuthApi{} + credExchange: func(t *testing.T) cmdutils.CredentialExchangeImpl { + return &mockCredExchangeApi{} }, secretStore: func(t *testing.T) cmdutils.SecretStorageImpl { ss := &mockSecretApi{} @@ -361,8 +296,8 @@ func Test_GetSamlCreds_With(t *testing.T) { return tc }, handler: IdpHandler, - authApi: func(t *testing.T) credentialexchange.AuthSamlApi { - return &mockAuthApi{} + credExchange: func(t *testing.T) cmdutils.CredentialExchangeImpl { + return &mockCredExchangeApi{} }, secretStore: func(t *testing.T) cmdutils.SecretStorageImpl { ss := &mockSecretApi{} @@ -382,12 +317,11 @@ func Test_GetSamlCreds_With(t *testing.T) { return tc }, handler: IdpHandler, - authApi: func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} - m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { - return nil, fmt.Errorf("get caller error") + credExchange: func(t *testing.T) cmdutils.CredentialExchangeImpl { + m := &mockCredExchangeApi{} + m.isValid = func(ctx context.Context, currentCreds *credentialexchange.AWSCredentials, reloadBeforeTime int) (bool, error) { + return false, fmt.Errorf("unable to validate") } - return m }, secretStore: func(t *testing.T) cmdutils.SecretStorageImpl { @@ -426,7 +360,7 @@ func Test_GetSamlCreds_With(t *testing.T) { ss := tt.secretStore(t) err := cmdutils.GetCredsWebUI( - context.TODO(), tt.authApi(t), ss, conf, + context.TODO(), tt.credExchange(t), ss, conf, web.NewWebConf(tempDir).WithHeadless().WithTimeout(10).WithNoSandbox()) if tt.expectErr { @@ -481,7 +415,7 @@ func Test_Get_SSO_Creds_with(t *testing.T) { ttests := map[string]struct { config func(t *testing.T) credentialexchange.CredentialConfig handler func(t *testing.T) http.Handler - authApi func(t *testing.T) credentialexchange.AuthSamlApi + authApi func(t *testing.T) cmdutils.CredentialExchangeImpl secretStore func(t *testing.T) cmdutils.SecretStorageImpl expectErr bool errTyp error @@ -491,30 +425,8 @@ func Test_Get_SSO_Creds_with(t *testing.T) { return testConfig() }, handler: mockSsoHandler, - authApi: func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} - m.assumeRoleWSaml = func(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { - return &sts.AssumeRoleWithSAMLOutput{}, nil - } - - m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { - // t.Error() - return &sts.GetCallerIdentityOutput{}, nil - } - m.assume = func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { - return &sts.AssumeRoleOutput{ - AssumedRoleUser: &types.AssumedRoleUser{ - AssumedRoleId: aws.String("some-role"), - Arn: aws.String("arn"), - }, - Credentials: &types.Credentials{ - AccessKeyId: aws.String("123213"), - SecretAccessKey: aws.String("32798hewf"), - SessionToken: aws.String("49hefusdSOM_LONG_TOKEN_HERE"), - Expiration: aws.Time(time.Now().Local().Add(time.Minute * time.Duration(5))), - }, - }, nil - } + authApi: func(t *testing.T) cmdutils.CredentialExchangeImpl { + m := &mockCredExchangeApi{} return m }, secretStore: func(t *testing.T) cmdutils.SecretStorageImpl { diff --git a/internal/credentialexchange/credentialexchange.go b/internal/credentialexchange/credentialexchange.go index 03ce845..4cc85ce 100755 --- a/internal/credentialexchange/credentialexchange.go +++ b/internal/credentialexchange/credentialexchange.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/smithy-go" + "github.com/rs/zerolog" ) var ( @@ -30,8 +31,25 @@ type authWebTokenApi interface { AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) } +type iamSvcIface interface { + authWebTokenApi + AuthSamlApi +} + +type CredentialExchange struct { + logger zerolog.Logger + svc iamSvcIface +} + +func New(logger zerolog.Logger, svc iamSvcIface) *CredentialExchange { + return &CredentialExchange{ + logger: logger, + svc: svc, + } +} + // LoginStsSaml exchanges saml response for STS creds -func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc AuthSamlApi) (*AWSCredentials, error) { +func (c *CredentialExchange) LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole) (*AWSCredentials, error) { params := &sts.AssumeRoleWithSAMLInput{ PrincipalArn: aws.String(role.PrincipalARN), // Required @@ -40,7 +58,7 @@ func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc Au DurationSeconds: aws.Int32(int32(role.Duration)), } - resp, err := svc.AssumeRoleWithSAML(ctx, params) + resp, err := c.svc.AssumeRoleWithSAML(ctx, params) if err != nil { return nil, fmt.Errorf("%w, failed to retrieve STS credentials using SAML: %s", ErrUnableAssume, err.Error()) } @@ -58,12 +76,13 @@ func LoginStsSaml(ctx context.Context, samlResponse string, role AWSRole, svc Au // returns them if they are still valid // if reloadTimeBefore is less than time left on the creds // then it will re-request a login -func IsValid(ctx context.Context, currentCreds *AWSCredentials, reloadBeforeTime int, svc AuthSamlApi) (bool, error) { +func (c *CredentialExchange) IsValid(ctx context.Context, currentCreds *AWSCredentials, reloadBeforeTime int) (bool, error) { if currentCreds == nil { return false, nil } - if _, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(o *sts.Options) { + if _, err := c.svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(o *sts.Options) { + // o.EndpointResolverV2 = &resolverProvider{} //.ResolveEndpoint(ctx, sts.EndpointParameters{}) o.Credentials = &credsProvider{currentCreds.AWSAccessKey, currentCreds.AWSSecretKey, currentCreds.AWSSessionToken, currentCreds.Expires} }); err != nil { // var oe *smithy.OperationError @@ -81,7 +100,7 @@ func IsValid(ctx context.Context, currentCreds *AWSCredentials, reloadBeforeTime } // LoginAwsWebToken -func LoginAwsWebToken(ctx context.Context, username string, svc authWebTokenApi) (*AWSCredentials, error) { +func (c *CredentialExchange) LoginAwsWebToken(ctx context.Context, username string) (*AWSCredentials, error) { // var role string r, exists := os.LookupEnv(AWS_ROLE_ARN) if !exists { @@ -99,7 +118,7 @@ func LoginAwsWebToken(ctx context.Context, username string, svc authWebTokenApi) WebIdentityToken: &token, } - resp, err := svc.AssumeRoleWithWebIdentity(ctx, input) + resp, err := c.svc.AssumeRoleWithWebIdentity(ctx, input) if err != nil { return nil, fmt.Errorf("failed to retrieve STS credentials using token file: %s, %w", err.Error(), ErrUnableAssume) } @@ -115,13 +134,13 @@ func LoginAwsWebToken(ctx context.Context, username string, svc authWebTokenApi) // AssumeRoleInChain loops over all the roles provided // If none are provided it will return the baseCreds -func AssumeRoleInChain(ctx context.Context, baseCreds *AWSCredentials, svc AuthSamlApi, username string, roles []string, conf CredentialConfig) (*AWSCredentials, error) { +func (c *CredentialExchange) AssumeRoleInChain(ctx context.Context, baseCreds *AWSCredentials, username string, roles []string, conf CredentialConfig) (*AWSCredentials, error) { duration := int32(900) for idx, r := range roles { if len(roles) == idx+1 { duration = int32(conf.Duration) } - c, err := assumeRoleWithCreds(ctx, baseCreds, svc, username, r, duration) + c, err := c.assumeRoleWithCreds(ctx, baseCreds, username, r, duration) if err != nil { return nil, err } @@ -135,7 +154,7 @@ func AssumeRoleInChain(ctx context.Context, baseCreds *AWSCredentials, svc AuthS // // Most common use case is role chaining an WeBId role to a specific one // duration is the -func assumeRoleWithCreds(ctx context.Context, currentCreds *AWSCredentials, svc AuthSamlApi, username, role string, duration int32) (*AWSCredentials, error) { +func (c *CredentialExchange) assumeRoleWithCreds(ctx context.Context, currentCreds *AWSCredentials, username, role string, duration int32) (*AWSCredentials, error) { timeNowPlusDuration := time.Now().Add(time.Duration(duration) * time.Second) @@ -145,7 +164,8 @@ func assumeRoleWithCreds(ctx context.Context, currentCreds *AWSCredentials, svc // DurationSeconds: &duration, } - roleCreds, err := svc.AssumeRole(ctx, input, func(o *sts.Options) { + c.logger.Debug().Any("timeNowPlusDuration", timeNowPlusDuration).Msgf("") + roleCreds, err := c.svc.AssumeRole(ctx, input, func(o *sts.Options) { o.Credentials = &credsProvider{currentCreds.AWSAccessKey, currentCreds.AWSSecretKey, currentCreds.AWSSessionToken, currentCreds.Expires} }) @@ -170,3 +190,10 @@ type credsProvider struct { func (c *credsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { return aws.Credentials{AccessKeyID: c.accessKey, SecretAccessKey: c.secretKey, SessionToken: c.sessionToken, CanExpire: true, Expires: c.expiry}, nil } + +// type resolverProvider struct { +// } + +// func (c *resolverProvider) ResolveEndpoint(ctx context.Context, params sts.EndpointParameters) (transport.Endpoint, error) { +// return transport.Endpoint{}, nil +// } diff --git a/internal/credentialexchange/credentialexchange_test.go b/internal/credentialexchange/credentialexchange_test.go index 5ff4dc1..038fdd3 100644 --- a/internal/credentialexchange/credentialexchange_test.go +++ b/internal/credentialexchange/credentialexchange_test.go @@ -14,26 +14,32 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/aws/smithy-go" + "github.com/rs/zerolog" ) -type mockAuthApi struct { +type mockIamSvcApi struct { + assumewithwebId func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) assumeRoleWSaml func(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) getCallId func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) assume func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) } -func (m *mockAuthApi) AssumeRoleWithSAML(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { +func (m *mockIamSvcApi) AssumeRoleWithSAML(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { return m.assumeRoleWSaml(ctx, params, optFns...) } -func (m *mockAuthApi) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { +func (m *mockIamSvcApi) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return m.getCallId(ctx, params, optFns...) } -func (m *mockAuthApi) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { +func (m *mockIamSvcApi) AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { return m.assume(ctx, params, optFns...) } +func (m *mockIamSvcApi) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + return m.assumewithwebId(ctx, params, optFns...) +} + var mockSuccessAwsCreds = &types.Credentials{ AccessKeyId: aws.String("123"), SecretAccessKey: aws.String("456"), @@ -43,13 +49,13 @@ var mockSuccessAwsCreds = &types.Credentials{ func Test_AssumeWithSaml_(t *testing.T) { ttests := map[string]struct { - srv func(t *testing.T) credentialexchange.AuthSamlApi + srv func(t *testing.T) *credentialexchange.CredentialExchange expectErr bool errTyp error }{ "succeeds with correct input": { - srv: func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.assumeRoleWSaml = func(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { if *params.RoleArn != "somerole" { t.Errorf("expected role: %s got: %s", "somerole", *params.RoleArn) @@ -59,21 +65,21 @@ func Test_AssumeWithSaml_(t *testing.T) { Credentials: mockSuccessAwsCreds, }, nil } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) }, expectErr: false, errTyp: nil, }, "fails on input": { - srv: func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.assumeRoleWSaml = func(ctx context.Context, params *sts.AssumeRoleWithSAMLInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithSAMLOutput, error) { if *params.RoleArn != "somerole" { t.Errorf("expected role: %s got: %s", "somerole", *params.RoleArn) } return nil, fmt.Errorf("some error") } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) }, expectErr: true, errTyp: credentialexchange.ErrUnableAssume, @@ -81,13 +87,12 @@ func Test_AssumeWithSaml_(t *testing.T) { } for name, tt := range ttests { t.Run(name, func(t *testing.T) { - got, err := credentialexchange.LoginStsSaml(context.TODO(), "samlAssertion...372dgh8ybjsdfviwehfiu9rwfe", + got, err := tt.srv(t).LoginStsSaml(context.TODO(), "samlAssertion...372dgh8ybjsdfviwehfiu9rwfe", credentialexchange.AWSRole{ RoleARN: "somerole", PrincipalARN: "someprincipal", Duration: 900, }, - tt.srv(t), ) if tt.expectErr { @@ -139,7 +144,7 @@ func (e *smithyErrTyp) ErrorFault() smithy.ErrorFault { func Test_IsValid_with(t *testing.T) { ttests := map[string]struct { - srv func(t *testing.T) credentialexchange.AuthSamlApi + srv func(t *testing.T) *credentialexchange.CredentialExchange currCred *credentialexchange.AWSCredentials reloadBefore int expectValid bool @@ -147,15 +152,15 @@ func Test_IsValid_with(t *testing.T) { errTyp error }{ "non expired credential with enough time before reload required": { - func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} + func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return &sts.GetCallerIdentityOutput{ Account: aws.String("account"), Arn: aws.String("arn"), }, nil } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) }, &credentialexchange.AWSCredentials{ AWSAccessKey: "stringjsonAccessKey", @@ -169,15 +174,15 @@ func Test_IsValid_with(t *testing.T) { nil, }, "credentials valid but need to reload before time fails": { - func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} + func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return &sts.GetCallerIdentityOutput{ Account: aws.String("account"), Arn: aws.String("arn"), }, nil } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) }, &credentialexchange.AWSCredentials{ AWSAccessKey: "stringjsonAccessKey", @@ -191,15 +196,16 @@ func Test_IsValid_with(t *testing.T) { nil, }, "expired credential": { - func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} + func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return nil, &smithyErrTyp{ err: func() string { return "some errr" }, errCode: func() string { return "ExpiredToken" }, } } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) + }, &credentialexchange.AWSCredentials{ AWSAccessKey: "stringjsonAccessKey", @@ -213,15 +219,16 @@ func Test_IsValid_with(t *testing.T) { nil, }, "another error when chekcing credential": { - func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} + func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.getCallId = func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return nil, &smithyErrTyp{ err: func() string { return "some errr" }, errCode: func() string { return "SomeOTherErr" }, } } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) + }, &credentialexchange.AWSCredentials{ AWSAccessKey: "stringjsonAccessKey", @@ -235,9 +242,9 @@ func Test_IsValid_with(t *testing.T) { credentialexchange.ErrUnableAssume, }, "no existing credential": { - func(t *testing.T) credentialexchange.AuthSamlApi { - m := &mockAuthApi{} - return m + func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} + return credentialexchange.New(zerolog.Nop().With().Logger(), m) }, nil, 120, @@ -248,7 +255,7 @@ func Test_IsValid_with(t *testing.T) { } for name, tt := range ttests { t.Run(name, func(t *testing.T) { - valid, err := credentialexchange.IsValid(context.TODO(), tt.currCred, tt.reloadBefore, tt.srv(t)) + valid, err := tt.srv(t).IsValid(context.TODO(), tt.currCred, tt.reloadBefore) if tt.expectErr { if err == nil { @@ -272,32 +279,24 @@ func Test_IsValid_with(t *testing.T) { } } -type authWebTokenApi struct { - assumewithwebId func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) -} - -func (a *authWebTokenApi) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { - return a.assumewithwebId(ctx, params, optFns...) -} - func Test_LoginAwsWebToken_with(t *testing.T) { ttests := map[string]struct { - srv func(t *testing.T) *authWebTokenApi + srv func(t *testing.T) *credentialexchange.CredentialExchange setup func() func() currCred *credentialexchange.AWSCredentials expectErr bool errTyp error }{ "succeeds with correct input": { - srv: func(t *testing.T) *authWebTokenApi { - a := &authWebTokenApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + a := &mockIamSvcApi{} a.assumewithwebId = func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return &sts.AssumeRoleWithWebIdentityOutput{ AssumedRoleUser: &types.AssumedRoleUser{Arn: aws.String("assumedRoleUser")}, Credentials: mockSuccessAwsCreds, }, nil } - return a + return credentialexchange.New(zerolog.Nop().With().Logger(), a) }, setup: func() func() { tmpDir, _ := os.MkdirTemp(os.TempDir(), "web-id") @@ -315,12 +314,12 @@ func Test_LoginAwsWebToken_with(t *testing.T) { errTyp: nil, }, "fails on rest call to assume": { - srv: func(t *testing.T) *authWebTokenApi { - a := &authWebTokenApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + a := &mockIamSvcApi{} a.assumewithwebId = func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return nil, fmt.Errorf("some err") } - return a + return credentialexchange.New(zerolog.Nop().With().Logger(), a) }, setup: func() func() { tmpDir, _ := os.MkdirTemp(os.TempDir(), "web-id") @@ -338,15 +337,15 @@ func Test_LoginAwsWebToken_with(t *testing.T) { errTyp: credentialexchange.ErrUnableAssume, }, "fails on missing role env VARS": { - srv: func(t *testing.T) *authWebTokenApi { - a := &authWebTokenApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + a := &mockIamSvcApi{} a.assumewithwebId = func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return &sts.AssumeRoleWithWebIdentityOutput{ AssumedRoleUser: &types.AssumedRoleUser{Arn: aws.String("assumedRoleUser")}, Credentials: mockSuccessAwsCreds, }, nil } - return a + return credentialexchange.New(zerolog.Nop().With().Logger(), a) }, setup: func() func() { return func() {} @@ -356,15 +355,16 @@ func Test_LoginAwsWebToken_with(t *testing.T) { errTyp: credentialexchange.ErrMissingEnvVar, }, "fails on missing token file env VARS": { - srv: func(t *testing.T) *authWebTokenApi { - a := &authWebTokenApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + a := &mockIamSvcApi{} a.assumewithwebId = func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { return &sts.AssumeRoleWithWebIdentityOutput{ AssumedRoleUser: &types.AssumedRoleUser{Arn: aws.String("assumedRoleUser")}, Credentials: mockSuccessAwsCreds, }, nil } - return a + return credentialexchange.New(zerolog.Nop().With().Logger(), a) + }, setup: func() func() { // tmpDir, _ := os.MkdirTemp(os.TempDir(), "web-id") @@ -387,7 +387,7 @@ func Test_LoginAwsWebToken_with(t *testing.T) { tearDown := tt.setup() defer tearDown() - got, err := credentialexchange.LoginAwsWebToken(context.TODO(), "username", tt.srv(t)) + got, err := tt.srv(t).LoginAwsWebToken(context.TODO(), "username") if tt.expectErr { if err == nil { @@ -412,30 +412,31 @@ func Test_LoginAwsWebToken_with(t *testing.T) { func Test_AssumeSpecifiedCreds_with(t *testing.T) { ttests := map[string]struct { - srv func(t *testing.T) *mockAuthApi + srv func(t *testing.T) *credentialexchange.CredentialExchange currCred *credentialexchange.AWSCredentials expectErr bool errTyp error }{ "successfully passed in creds from somewhere": { - srv: func(t *testing.T) *mockAuthApi { - m := &mockAuthApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.assume = func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { return &sts.AssumeRoleOutput{ AssumedRoleUser: &types.AssumedRoleUser{Arn: aws.String("somearn")}, Credentials: mockSuccessAwsCreds, }, nil } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) }, }, "error on calling AssumeRole API": { - srv: func(t *testing.T) *mockAuthApi { - m := &mockAuthApi{} + srv: func(t *testing.T) *credentialexchange.CredentialExchange { + m := &mockIamSvcApi{} m.assume = func(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) { return nil, fmt.Errorf("some error") } - return m + return credentialexchange.New(zerolog.Nop().With().Logger(), m) + }, expectErr: true, errTyp: credentialexchange.ErrUnableAssume, @@ -443,7 +444,7 @@ func Test_AssumeSpecifiedCreds_with(t *testing.T) { } for name, tt := range ttests { t.Run(name, func(t *testing.T) { - got, err := credentialexchange.AssumeRoleInChain(context.TODO(), tt.currCred, tt.srv(t), "foo", []string{"barrole"}, credentialexchange.CredentialConfig{Duration: 14400}) + got, err := tt.srv(t).AssumeRoleInChain(context.TODO(), tt.currCred, "foo", []string{"barrole"}, credentialexchange.CredentialConfig{Duration: 14400}) if tt.expectErr { if err == nil {