Skip to content

Commit a645d6d

Browse files
authored
Implement client ID as config option (#449)
* implement client ID as config option * clean up comment
1 parent faf5e5f commit a645d6d

File tree

10 files changed

+103
-4
lines changed

10 files changed

+103
-4
lines changed

docs/stackit_config_set.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ stackit config set [flags]
3434
--dns-custom-endpoint string DNS API base URL, used in calls to this API
3535
-h, --help Help for "stackit config set"
3636
--iaas-custom-endpoint string IaaS API base URL, used in calls to this API
37+
--identity-provider-custom-client-id string Identity Provider client ID, used for user authentication
3738
--identity-provider-custom-endpoint string Identity Provider base URL, used for user authentication
3839
--load-balancer-custom-endpoint string Load Balancer API base URL, used in calls to this API
3940
--logme-custom-endpoint string LogMe API base URL, used in calls to this API

docs/stackit_config_unset.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ stackit config unset [flags]
3232
--dns-custom-endpoint DNS API base URL. If unset, uses the default base URL
3333
-h, --help Help for "stackit config unset"
3434
--iaas-custom-endpoint IaaS API base URL. If unset, uses the default base URL
35+
--identity-provider-custom-client-id Identity Provider client ID, used for user authentication
3536
--identity-provider-custom-endpoint Identity Provider base URL. If unset, uses the default base URL
3637
--load-balancer-custom-endpoint Load Balancer API base URL. If unset, uses the default base URL
3738
--logme-custom-endpoint LogMe API base URL. If unset, uses the default base URL

internal/cmd/config/set/set.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
const (
2020
sessionTimeLimitFlag = "session-time-limit"
2121
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
22+
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
2223

2324
argusCustomEndpointFlag = "argus-custom-endpoint"
2425
authorizationCustomEndpointFlag = "authorization-custom-endpoint"
@@ -129,6 +130,7 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e
129130
func configureFlags(cmd *cobra.Command) {
130131
cmd.Flags().String(sessionTimeLimitFlag, "", "Maximum time before authentication is required again. After this time, you will be prompted to login again to execute commands that require authentication. Can't be larger than 24h. Requires authentication after being set to take effect. Examples: 3h, 5h30m40s (BETA: currently values greater than 2h have no effect)")
131132
cmd.Flags().String(identityProviderCustomEndpointFlag, "", "Identity Provider base URL, used for user authentication")
133+
cmd.Flags().String(identityProviderCustomClientIdFlag, "", "Identity Provider client ID, used for user authentication")
132134
cmd.Flags().String(argusCustomEndpointFlag, "", "Argus API base URL, used in calls to this API")
133135
cmd.Flags().String(authorizationCustomEndpointFlag, "", "Authorization API base URL, used in calls to this API")
134136
cmd.Flags().String(dnsCustomEndpointFlag, "", "DNS API base URL, used in calls to this API")
@@ -155,6 +157,8 @@ func configureFlags(cmd *cobra.Command) {
155157
cobra.CheckErr(err)
156158
err = viper.BindPFlag(config.IdentityProviderCustomEndpointKey, cmd.Flags().Lookup(identityProviderCustomEndpointFlag))
157159
cobra.CheckErr(err)
160+
err = viper.BindPFlag(config.IdentityProviderCustomClientIdKey, cmd.Flags().Lookup(identityProviderCustomClientIdFlag))
161+
cobra.CheckErr(err)
158162

159163
err = viper.BindPFlag(config.ArgusCustomEndpointKey, cmd.Flags().Lookup(argusCustomEndpointFlag))
160164
cobra.CheckErr(err)

internal/cmd/config/unset/unset.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const (
2222

2323
sessionTimeLimitFlag = "session-time-limit"
2424
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
25+
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"
2526

2627
argusCustomEndpointFlag = "argus-custom-endpoint"
2728
authorizationCustomEndpointFlag = "authorization-custom-endpoint"
@@ -54,6 +55,7 @@ type inputModel struct {
5455

5556
SessionTimeLimit bool
5657
IdentityProviderCustomEndpoint bool
58+
IdentityProviderCustomClientID bool
5759

5860
ArgusCustomEndpoint bool
5961
AuthorizationCustomEndpoint bool
@@ -117,6 +119,9 @@ func NewCmd(p *print.Printer) *cobra.Command {
117119
if model.IdentityProviderCustomEndpoint {
118120
viper.Set(config.IdentityProviderCustomEndpointKey, "")
119121
}
122+
if model.IdentityProviderCustomClientID {
123+
viper.Set(config.IdentityProviderCustomClientIdKey, "")
124+
}
120125

121126
if model.ArgusCustomEndpoint {
122127
viper.Set(config.ArgusCustomEndpointKey, "")
@@ -201,6 +206,7 @@ func configureFlags(cmd *cobra.Command) {
201206

202207
cmd.Flags().Bool(sessionTimeLimitFlag, false, fmt.Sprintf("Maximum time before authentication is required again. If unset, defaults to %s", config.SessionTimeLimitDefault))
203208
cmd.Flags().Bool(identityProviderCustomEndpointFlag, false, "Identity Provider base URL. If unset, uses the default base URL")
209+
cmd.Flags().Bool(identityProviderCustomClientIdFlag, false, "Identity Provider client ID, used for user authentication")
204210

205211
cmd.Flags().Bool(argusCustomEndpointFlag, false, "Argus API base URL. If unset, uses the default base URL")
206212
cmd.Flags().Bool(authorizationCustomEndpointFlag, false, "Authorization API base URL. If unset, uses the default base URL")
@@ -234,6 +240,7 @@ func parseInput(p *print.Printer, cmd *cobra.Command) *inputModel {
234240

235241
SessionTimeLimit: flags.FlagToBoolValue(p, cmd, sessionTimeLimitFlag),
236242
IdentityProviderCustomEndpoint: flags.FlagToBoolValue(p, cmd, identityProviderCustomEndpointFlag),
243+
IdentityProviderCustomClientID: flags.FlagToBoolValue(p, cmd, identityProviderCustomClientIdFlag),
237244

238245
ArgusCustomEndpoint: flags.FlagToBoolValue(p, cmd, argusCustomEndpointFlag),
239246
AuthorizationCustomEndpoint: flags.FlagToBoolValue(p, cmd, authorizationCustomEndpointFlag),

internal/cmd/config/unset/unset_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ func fixtureFlagValues(mods ...func(flagValues map[string]bool)) map[string]bool
1818

1919
sessionTimeLimitFlag: true,
2020
identityProviderCustomEndpointFlag: true,
21+
identityProviderCustomClientIdFlag: true,
2122

2223
argusCustomEndpointFlag: true,
2324
authorizationCustomEndpointFlag: true,
@@ -53,6 +54,7 @@ func fixtureInputModel(mods ...func(model *inputModel)) *inputModel {
5354

5455
SessionTimeLimit: true,
5556
IdentityProviderCustomEndpoint: true,
57+
IdentityProviderCustomClientID: true,
5658

5759
ArgusCustomEndpoint: true,
5860
AuthorizationCustomEndpoint: true,
@@ -104,6 +106,7 @@ func TestParseInput(t *testing.T) {
104106

105107
model.SessionTimeLimit = false
106108
model.IdentityProviderCustomEndpoint = false
109+
model.IdentityProviderCustomClientID = false
107110

108111
model.ArgusCustomEndpoint = false
109112
model.AuthorizationCustomEndpoint = false
@@ -155,6 +158,16 @@ func TestParseInput(t *testing.T) {
155158
model.IdentityProviderCustomEndpoint = false
156159
}),
157160
},
161+
{
162+
description: "identity provider custom client id empty",
163+
flagValues: fixtureFlagValues(func(flagValues map[string]bool) {
164+
flagValues[identityProviderCustomClientIdFlag] = false
165+
}),
166+
isValid: true,
167+
expectedModel: fixtureInputModel(func(model *inputModel) {
168+
model.IdentityProviderCustomClientID = false
169+
}),
170+
},
158171
{
159172
description: "argus custom endpoint empty",
160173
flagValues: fixtureFlagValues(func(flagValues map[string]bool) {

internal/pkg/auth/user_login.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import (
2424

2525
const (
2626
defaultIDPEndpoint = "https://accounts.stackit.cloud/oauth/v2"
27-
cliClientID = "stackit-cli-0000-0000-000000000001"
27+
defaultCLIClientID = "stackit-cli-0000-0000-000000000001"
2828

2929
loginSuccessPath = "/login-successful"
3030
stackitLandingPage = "https://www.stackit.de"
@@ -58,6 +58,18 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
5858
}
5959
}
6060

61+
idpClientID, err := getIDPClientID()
62+
if err != nil {
63+
return err
64+
}
65+
if idpClientID != defaultCLIClientID {
66+
p.Warn("You are using a custom client ID (%s) for authentication.\n", idpClientID)
67+
err := p.PromptForEnter("Press Enter to proceed with the login...")
68+
if err != nil {
69+
return err
70+
}
71+
}
72+
6173
if isReauthentication {
6274
err := p.PromptForEnter("Your session has expired, press Enter to login again...")
6375
if err != nil {
@@ -86,7 +98,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
8698
}
8799

88100
conf := &oauth2.Config{
89-
ClientID: cliClientID,
101+
ClientID: idpClientID,
90102
Endpoint: oauth2.Endpoint{
91103
AuthURL: fmt.Sprintf("%s/authorize", idpEndpoint),
92104
},
@@ -131,7 +143,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
131143
p.Debug(print.DebugLevel, "trading authorization code for access and refresh tokens")
132144

133145
// Trade the authorization code and the code verifier for access and refresh tokens
134-
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, cliClientID, codeVerifier, code, redirectURL)
146+
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, idpClientID, codeVerifier, code, redirectURL)
135147
if err != nil {
136148
errServer = fmt.Errorf("retrieve tokens: %w", err)
137149
return
@@ -207,6 +219,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
207219

208220
p.Debug(print.DebugLevel, "opening browser for authentication")
209221
p.Debug(print.DebugLevel, "using authentication server on %s", idpEndpoint)
222+
p.Debug(print.DebugLevel, "using client ID %s for authentication ", idpClientID)
210223

211224
// Open a browser window to the authorizationURL
212225
err = openBrowser(authorizationURL)

internal/pkg/auth/user_token_flow.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ func buildRequestToRefreshTokens(utf *userTokenFlow) (*http.Request, error) {
161161
return nil, err
162162
}
163163

164+
idpClientID, err := getIDPClientID()
165+
if err != nil {
166+
return nil, err
167+
}
168+
164169
req, err := http.NewRequest(
165170
http.MethodPost,
166171
fmt.Sprintf("%s/token", idpEndpoint),
@@ -171,7 +176,7 @@ func buildRequestToRefreshTokens(utf *userTokenFlow) (*http.Request, error) {
171176
}
172177
reqQuery := url.Values{}
173178
reqQuery.Set("grant_type", "refresh_token")
174-
reqQuery.Set("client_id", cliClientID)
179+
reqQuery.Set("client_id", idpClientID)
175180
reqQuery.Set("refresh_token", utf.refreshToken)
176181
reqQuery.Set("token_format", "jwt")
177182
req.URL.RawQuery = reqQuery.Encode()

internal/pkg/auth/utils.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@ func getIDPEndpoint() (string, error) {
2222

2323
return idpEndpoint, nil
2424
}
25+
26+
func getIDPClientID() (string, error) {
27+
idpClientID := defaultCLIClientID
28+
29+
customIDPClientID := viper.GetString(config.IdentityProviderCustomClientIdKey)
30+
if customIDPClientID != "" {
31+
idpClientID = customIDPClientID
32+
}
33+
34+
return idpClientID, nil
35+
}

internal/pkg/auth/utils_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,44 @@ func TestGetIDPEndpoint(t *testing.T) {
5252
})
5353
}
5454
}
55+
56+
func TestGetIDPClientID(t *testing.T) {
57+
tests := []struct {
58+
name string
59+
idpCustomClientID string
60+
isValid bool
61+
expected string
62+
}{
63+
{
64+
name: "custom client ID specified",
65+
idpCustomClientID: "custom-client-id",
66+
isValid: true,
67+
expected: "custom-client-id",
68+
},
69+
{
70+
name: "custom client ID not specified",
71+
idpCustomClientID: "",
72+
isValid: true,
73+
expected: defaultCLIClientID,
74+
},
75+
}
76+
for _, tt := range tests {
77+
t.Run(tt.name, func(t *testing.T) {
78+
viper.Reset()
79+
viper.Set(config.IdentityProviderCustomClientIdKey, tt.idpCustomClientID)
80+
81+
got, err := getIDPClientID()
82+
83+
if tt.isValid && err != nil {
84+
t.Fatalf("expected no error, got %v", err)
85+
}
86+
if !tt.isValid && err == nil {
87+
t.Fatalf("expected error, got none")
88+
}
89+
90+
if got != tt.expected {
91+
t.Fatalf("expected idp client ID %q, got %q", tt.expected, got)
92+
}
93+
})
94+
}
95+
}

internal/pkg/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const (
1818
VerbosityKey = "verbosity"
1919

2020
IdentityProviderCustomEndpointKey = "identity_provider_custom_endpoint"
21+
IdentityProviderCustomClientIdKey = "identity_provider_custom_client_id"
2122

2223
ArgusCustomEndpointKey = "argus_custom_endpoint"
2324
AuthorizationCustomEndpointKey = "authorization_custom_endpoint"
@@ -67,6 +68,7 @@ var ConfigKeys = []string{
6768
VerbosityKey,
6869

6970
IdentityProviderCustomEndpointKey,
71+
IdentityProviderCustomClientIdKey,
7072

7173
DNSCustomEndpointKey,
7274
LoadBalancerCustomEndpointKey,
@@ -148,6 +150,7 @@ func setConfigDefaults() {
148150
viper.SetDefault(ProjectIdKey, "")
149151
viper.SetDefault(SessionTimeLimitKey, SessionTimeLimitDefault)
150152
viper.SetDefault(IdentityProviderCustomEndpointKey, "")
153+
viper.SetDefault(IdentityProviderCustomClientIdKey, "")
151154
viper.SetDefault(DNSCustomEndpointKey, "")
152155
viper.SetDefault(ArgusCustomEndpointKey, "")
153156
viper.SetDefault(AuthorizationCustomEndpointKey, "")

0 commit comments

Comments
 (0)