From f8ee4b1331b7a95047b874642f0324054aa194d3 Mon Sep 17 00:00:00 2001 From: Ashley Davis Date: Thu, 5 Feb 2026 16:27:42 +0000 Subject: [PATCH 1/2] use sigv4 for sending data to s3 Just using sigv4 shouldn't require changes, but the backend made a couple of extra requirements alongside changing to sigv4 - requiring headers for tagging and encryption. Signed-off-by: Ashley Davis --- internal/cyberark/client.go | 4 +- internal/cyberark/client_test.go | 8 +-- internal/cyberark/dataupload/dataupload.go | 65 ++++++++++++------- .../cyberark/dataupload/dataupload_test.go | 11 ++-- .../identity/advance_authentication_test.go | 11 +++- .../identity/authenticated_http_client.go | 14 ++-- .../identity/cmd/testidentity/main.go | 2 +- internal/cyberark/identity/identity.go | 12 +++- internal/cyberark/identity/identity_test.go | 2 +- .../cyberark/servicediscovery/discovery.go | 24 ++++--- .../servicediscovery/discovery_test.go | 2 +- pkg/client/client_cyberark.go | 4 +- 12 files changed, 99 insertions(+), 60 deletions(-) diff --git a/internal/cyberark/client.go b/internal/cyberark/client.go index 3d553142..92710296 100644 --- a/internal/cyberark/client.go +++ b/internal/cyberark/client.go @@ -49,7 +49,7 @@ func LoadClientConfigFromEnvironment() (ClientConfig, error) { // NewDatauploadClient initializes and returns a new CyberArk Data Upload client. // It performs service discovery to find the necessary API endpoints and authenticates // using the provided client configuration. -func NewDatauploadClient(ctx context.Context, httpClient *http.Client, serviceMap *servicediscovery.Services, cfg ClientConfig) (*dataupload.CyberArkClient, error) { +func NewDatauploadClient(ctx context.Context, httpClient *http.Client, serviceMap *servicediscovery.Services, tenantUUID string, cfg ClientConfig) (*dataupload.CyberArkClient, error) { identityAPI := serviceMap.Identity.API if identityAPI == "" { return nil, errors.New("service discovery returned an empty identity API") @@ -67,5 +67,5 @@ func NewDatauploadClient(ctx context.Context, httpClient *http.Client, serviceMa return nil, err } - return dataupload.New(httpClient, discoveryAPI, identityClient.AuthenticateRequest), nil + return dataupload.New(httpClient, discoveryAPI, tenantUUID, identityClient.AuthenticateRequest), nil } diff --git a/internal/cyberark/client_test.go b/internal/cyberark/client_test.go index ae7162ae..1c220d2d 100644 --- a/internal/cyberark/client_test.go +++ b/internal/cyberark/client_test.go @@ -34,12 +34,12 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) { discoveryClient := servicediscovery.New(httpClient) - serviceMap, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) + serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) if err != nil { t.Fatalf("failed to discover mock services: %v", err) } - cl, err := cyberark.NewDatauploadClient(ctx, httpClient, serviceMap, cfg) + cl, err := cyberark.NewDatauploadClient(ctx, httpClient, serviceMap, tenantUUID, cfg) require.NoError(t, err) err = cl.PutSnapshot(ctx, dataupload.Snapshot{ @@ -78,12 +78,12 @@ func TestCyberArkClient_PutSnapshot_RealAPI(t *testing.T) { discoveryClient := servicediscovery.New(httpClient) - serviceMap, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) + serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(t.Context(), cfg.Subdomain) if err != nil { t.Fatalf("failed to discover services: %v", err) } - cl, err := cyberark.NewDatauploadClient(ctx, httpClient, serviceMap, cfg) + cl, err := cyberark.NewDatauploadClient(ctx, httpClient, serviceMap, tenantUUID, cfg) require.NoError(t, err) err = cl.PutSnapshot(ctx, dataupload.Snapshot{ diff --git a/internal/cyberark/dataupload/dataupload.go b/internal/cyberark/dataupload/dataupload.go index 0d5bcc08..bf4434f0 100644 --- a/internal/cyberark/dataupload/dataupload.go +++ b/internal/cyberark/dataupload/dataupload.go @@ -15,6 +15,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" arkapi "github.com/jetstack/preflight/internal/cyberark/api" + "github.com/jetstack/preflight/internal/cyberark/identity" "github.com/jetstack/preflight/pkg/version" ) @@ -33,13 +34,19 @@ type CyberArkClient struct { baseURL string httpClient *http.Client - authenticateRequest func(req *http.Request) error + tenantUUID string + + authenticateRequest identity.RequestAuthenticator } -func New(httpClient *http.Client, baseURL string, authenticateRequest func(req *http.Request) error) *CyberArkClient { +// New creates a new CyberArkClient. The tenant UUID is best sourced from service discovery along with the base URL. +func New(httpClient *http.Client, baseURL string, tenantUUID string, authenticateRequest identity.RequestAuthenticator) *CyberArkClient { return &CyberArkClient{ - baseURL: baseURL, - httpClient: httpClient, + baseURL: baseURL, + httpClient: httpClient, + + tenantUUID: tenantUUID, + authenticateRequest: authenticateRequest, } } @@ -102,13 +109,6 @@ type Snapshot struct { // has been received intact. // Read [Checking object integrity for data uploads in Amazon S3](https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity-upload.html), // to learn more. -// -// TODO(wallrj): There is a bug in the AWS backend: -// [S3 Presigned PutObjectCommand URLs ignore Sha256 Hash when uploading](https://github.com/aws/aws-sdk/issues/480) -// ...which means that the `x-amz-checksum-sha256` request header is optional. -// If you omit that header, it is possible to PUT any data. -// There is a work around listed in that issue which we have shared with the -// CyberArk API team. func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) error { if snapshot.ClusterID == "" { return fmt.Errorf("programmer mistake: the snapshot cluster ID cannot be left empty") @@ -119,10 +119,12 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err if err := json.NewEncoder(io.MultiWriter(encodedBody, hash)).Encode(snapshot); err != nil { return err } + checksum := hash.Sum(nil) checksumHex := hex.EncodeToString(checksum) checksumBase64 := base64.StdEncoding.EncodeToString(checksum) - presignedUploadURL, err := c.retrievePresignedUploadURL(ctx, checksumHex, snapshot.ClusterID) + + presignedUploadURL, username, err := c.retrievePresignedUploadURL(ctx, checksumHex, snapshot.ClusterID) if err != nil { return fmt.Errorf("while retrieving snapshot upload URL: %s", err) } @@ -132,7 +134,21 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err if err != nil { return err } + req.Header.Set("X-Amz-Checksum-Sha256", checksumBase64) + req.Header.Set("X-Amz-Server-Side-Encryption", "AES256") + + q := url.Values{} + + q.Add("agent_version", snapshot.AgentVersion) + q.Add("tenant_id", c.tenantUUID) + q.Add("upload_type", "k8s_snapshot") + q.Add("uploader_id", snapshot.ClusterID) + q.Add("username", username) + q.Add("vendor", "k8s") + + req.Header.Set("X-Amz-Tagging", q.Encode()) + version.SetUserAgent(req) res, err := c.httpClient.Do(req) @@ -152,10 +168,10 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err return nil } -func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksum string, clusterID string) (string, error) { +func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksum string, clusterID string) (string, string, error) { uploadURL, err := url.JoinPath(c.baseURL, apiPathSnapshotLinks) if err != nil { - return "", err + return "", "", err } request := struct { @@ -170,18 +186,21 @@ func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksu encodedBody := &bytes.Buffer{} if err := json.NewEncoder(encodedBody).Encode(request); err != nil { - return "", err + return "", "", err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadURL, encodedBody) if err != nil { - return "", err + return "", "", err } req.Header.Set("Content-Type", "application/json") - if err := c.authenticateRequest(req); err != nil { - return "", fmt.Errorf("failed to authenticate request: %s", err) + + username, err := c.authenticateRequest(req) + if err != nil { + return "", "", fmt.Errorf("failed to authenticate request: %s", err) } + version.SetUserAgent(req) // Add telemetry headers @@ -189,7 +208,7 @@ func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksu res, err := c.httpClient.Do(req) if err != nil { - return "", err + return "", "", err } defer res.Body.Close() @@ -198,7 +217,7 @@ func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksu if len(body) == 0 { body = []byte(``) } - return "", fmt.Errorf("received response with status code %d: %s", code, bytes.TrimSpace(body)) + return "", "", fmt.Errorf("received response with status code %d: %s", code, bytes.TrimSpace(body)) } response := struct { @@ -207,11 +226,11 @@ func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksu if err := json.NewDecoder(io.LimitReader(res.Body, maxRetrievePresignedUploadURLBodySize)).Decode(&response); err != nil { if err == io.ErrUnexpectedEOF { - return "", fmt.Errorf("rejecting JSON response from server as it was too large or was truncated") + return "", "", fmt.Errorf("rejecting JSON response from server as it was too large or was truncated") } - return "", fmt.Errorf("failed to parse JSON from otherwise successful request to start data upload: %s", err) + return "", "", fmt.Errorf("failed to parse JSON from otherwise successful request to start data upload: %s", err) } - return response.URL, nil + return response.URL, username, nil } diff --git a/internal/cyberark/dataupload/dataupload_test.go b/internal/cyberark/dataupload/dataupload_test.go index f38a0e51..d78c4bf3 100644 --- a/internal/cyberark/dataupload/dataupload_test.go +++ b/internal/cyberark/dataupload/dataupload_test.go @@ -10,6 +10,7 @@ import ( "k8s.io/klog/v2/ktesting" "github.com/jetstack/preflight/internal/cyberark/dataupload" + "github.com/jetstack/preflight/internal/cyberark/identity" "github.com/jetstack/preflight/pkg/version" _ "k8s.io/klog/v2/ktesting/init" @@ -19,17 +20,17 @@ import ( // mock API server. The mock server is configured to return different responses // based on the cluster ID and bearer token used in the request. func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) { - setToken := func(token string) func(*http.Request) error { - return func(req *http.Request) error { + setToken := func(token string) identity.RequestAuthenticator { + return func(req *http.Request) (string, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - return nil + return "foo@example.com", nil // set a dummy username for testing purposes; the actual value is not important for these tests } } tests := []struct { name string snapshot dataupload.Snapshot - authenticate func(req *http.Request) error + authenticate identity.RequestAuthenticator requireFn func(t *testing.T, err error) }{ { @@ -96,7 +97,7 @@ func TestCyberArkClient_PutSnapshot_MockAPI(t *testing.T) { datauploadAPIBaseURL, httpClient := dataupload.MockDataUploadServer(t) - cyberArkClient := dataupload.New(httpClient, datauploadAPIBaseURL, tc.authenticate) + cyberArkClient := dataupload.New(httpClient, datauploadAPIBaseURL, "test-tenant-uuid", tc.authenticate) err := cyberArkClient.PutSnapshot(ctx, tc.snapshot) tc.requireFn(t, err) diff --git a/internal/cyberark/identity/advance_authentication_test.go b/internal/cyberark/identity/advance_authentication_test.go index 9340da30..0c17cd0b 100644 --- a/internal/cyberark/identity/advance_authentication_test.go +++ b/internal/cyberark/identity/advance_authentication_test.go @@ -131,13 +131,18 @@ func Test_IdentityAdvanceAuthentication(t *testing.T) { return } - if len(client.tokenCached) == 0 { + if client.tokenCached.Username != testSpec.username { + t.Errorf("expected username %s to be set on cached token after authentication but got %q", testSpec.username, client.tokenCached.Username) + return + } + + if len(client.tokenCached.Token) == 0 { t.Errorf("expected token for %s to be set to %q but wasn't found", testSpec.username, mockSuccessfulStartAuthenticationToken) return } - if client.tokenCached != mockSuccessfulStartAuthenticationToken { - t.Errorf("expected token for %s to be set to %q but was set to %q", testSpec.username, mockSuccessfulStartAuthenticationToken, client.tokenCached) + if client.tokenCached.Token != mockSuccessfulStartAuthenticationToken { + t.Errorf("expected token for %s to be set to %q but was set to %q", testSpec.username, mockSuccessfulStartAuthenticationToken, client.tokenCached.Token) } }) } diff --git a/internal/cyberark/identity/authenticated_http_client.go b/internal/cyberark/identity/authenticated_http_client.go index 901d14db..c20d5bfb 100644 --- a/internal/cyberark/identity/authenticated_http_client.go +++ b/internal/cyberark/identity/authenticated_http_client.go @@ -5,15 +5,19 @@ import ( "net/http" ) -func (c *Client) AuthenticateRequest(req *http.Request) error { +type RequestAuthenticator func(req *http.Request) (string, error) + +// AuthenticateRequest is a helper function that adds the Authorization header to an HTTP request using a cached token. +// It sets the Header directly, and if successful returns the username corresponding to the token. +func (c *Client) AuthenticateRequest(req *http.Request) (string, error) { c.tokenCachedMutex.Lock() defer c.tokenCachedMutex.Unlock() - if len(c.tokenCached) == 0 { - return fmt.Errorf("no token cached") + if len(c.tokenCached.Token) == 0 { + return "", fmt.Errorf("no token cached") } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(c.tokenCached))) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.tokenCached.Token)) - return nil + return c.tokenCached.Username, nil } diff --git a/internal/cyberark/identity/cmd/testidentity/main.go b/internal/cyberark/identity/cmd/testidentity/main.go index 8729cfbe..916c81ea 100644 --- a/internal/cyberark/identity/cmd/testidentity/main.go +++ b/internal/cyberark/identity/cmd/testidentity/main.go @@ -51,7 +51,7 @@ func run(ctx context.Context) error { httpClient := http_client.NewDefaultClient(version.UserAgent(), rootCAs) sdClient := servicediscovery.New(httpClient) - services, err := sdClient.DiscoverServices(ctx, subdomain) + services, _, err := sdClient.DiscoverServices(ctx, subdomain) if err != nil { return fmt.Errorf("while performing service discovery: %s", err) } diff --git a/internal/cyberark/identity/identity.go b/internal/cyberark/identity/identity.go index e88ba0c1..4f9d6156 100644 --- a/internal/cyberark/identity/identity.go +++ b/internal/cyberark/identity/identity.go @@ -183,7 +183,10 @@ type Client struct { } // token is a wrapper type for holding auth tokens we want to cache. -type token string +type token struct { + Username string + Token string +} // New returns an initialized CyberArk Identity client using a default service discovery client. func New(httpClient *http.Client, baseURL string, subdomain string) *Client { @@ -192,7 +195,7 @@ func New(httpClient *http.Client, baseURL string, subdomain string) *Client { baseURL: baseURL, subdomain: subdomain, - tokenCached: "", + tokenCached: token{}, tokenCachedMutex: sync.Mutex{}, } } @@ -404,7 +407,10 @@ func (c *Client) doAdvanceAuthentication(ctx context.Context, username string, p c.tokenCachedMutex.Lock() - c.tokenCached = token(advanceAuthResponse.Result.Token) + c.tokenCached = token{ + Username: username, + Token: advanceAuthResponse.Result.Token, + } c.tokenCachedMutex.Unlock() diff --git a/internal/cyberark/identity/identity_test.go b/internal/cyberark/identity/identity_test.go index 732805e7..917ba15d 100644 --- a/internal/cyberark/identity/identity_test.go +++ b/internal/cyberark/identity/identity_test.go @@ -53,7 +53,7 @@ func TestLoginUsernamePassword_RealAPI(t *testing.T) { arktesting.SkipIfNoEnv(t) subdomain := os.Getenv("ARK_SUBDOMAIN") httpClient := http.DefaultClient - services, err := servicediscovery.New(httpClient).DiscoverServices(t.Context(), subdomain) + services, _, err := servicediscovery.New(httpClient).DiscoverServices(t.Context(), subdomain) require.NoError(t, err) loginUsernamePasswordTests(t, func(t testing.TB) inputs { diff --git a/internal/cyberark/servicediscovery/discovery.go b/internal/cyberark/servicediscovery/discovery.go index e838e507..82394ab3 100644 --- a/internal/cyberark/servicediscovery/discovery.go +++ b/internal/cyberark/servicediscovery/discovery.go @@ -95,17 +95,21 @@ type Services struct { // DiscoverServices fetches from the service discovery service for a given subdomain // and parses the CyberArk Identity API URL and Inventory API URL. -func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Services, error) { +// It also returns the Tenant ID UUID corresponding to the subdomain. +func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Services, string, error) { u, err := url.Parse(c.baseURL) if err != nil { - return nil, fmt.Errorf("invalid base URL for service discovery: %w", err) + return nil, "", fmt.Errorf("invalid base URL for service discovery: %w", err) } + u.Path = path.Join(u.Path, "api/public/tenant-discovery") u.RawQuery = url.Values{"bySubdomain": []string{subdomain}}.Encode() + endpoint := u.String() + request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return nil, fmt.Errorf("failed to initialise request to %s: %s", endpoint, err) + return nil, "", fmt.Errorf("failed to initialise request to %s: %s", endpoint, err) } request.Header.Set("Accept", "application/json") @@ -114,7 +118,7 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi arkapi.SetTelemetryRequestHeader(request) resp, err := c.client.Do(request) if err != nil { - return nil, fmt.Errorf("failed to perform HTTP request: %s", err) + return nil, "", fmt.Errorf("failed to perform HTTP request: %s", err) } defer resp.Body.Close() @@ -123,19 +127,19 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi // a 404 error is returned with an empty JSON body "{}" if the subdomain is unknown; at the time of writing, we haven't observed // any other errors and so we can't special case them if resp.StatusCode == http.StatusNotFound { - return nil, fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", subdomain) + return nil, "", fmt.Errorf("got an HTTP 404 response from service discovery; maybe the subdomain %q is incorrect or does not exist?", subdomain) } - return nil, fmt.Errorf("got unexpected status code %s from request to service discovery API", resp.Status) + return nil, "", fmt.Errorf("got unexpected status code %s from request to service discovery API", resp.Status) } var discoveryResp DiscoveryResponse err = json.NewDecoder(io.LimitReader(resp.Body, maxDiscoverBodySize)).Decode(&discoveryResp) if err != nil { if err == io.ErrUnexpectedEOF { - return nil, fmt.Errorf("rejecting JSON response from server as it was too large or was truncated") + return nil, "", fmt.Errorf("rejecting JSON response from server as it was too large or was truncated") } - return nil, fmt.Errorf("failed to parse JSON from otherwise successful request to service discovery endpoint: %s", err) + return nil, "", fmt.Errorf("failed to parse JSON from otherwise successful request to service discovery endpoint: %s", err) } var identityAPI, discoveryContextAPI string for _, svc := range discoveryResp.Services { @@ -158,7 +162,7 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi } if identityAPI == "" { - return nil, fmt.Errorf("didn't find %s in service discovery response, "+ + return nil, "", fmt.Errorf("didn't find %s in service discovery response, "+ "which may indicate a suspended tenant; unable to detect CyberArk Identity API URL", IdentityServiceName) } //TODO: Should add a check for discoveryContextAPI too? @@ -166,5 +170,5 @@ func (c *Client) DiscoverServices(ctx context.Context, subdomain string) (*Servi return &Services{ Identity: ServiceEndpoint{API: identityAPI}, DiscoveryContext: ServiceEndpoint{API: discoveryContextAPI}, - }, nil + }, discoveryResp.TenantID, nil } diff --git a/internal/cyberark/servicediscovery/discovery_test.go b/internal/cyberark/servicediscovery/discovery_test.go index d1091307..00d0fd58 100644 --- a/internal/cyberark/servicediscovery/discovery_test.go +++ b/internal/cyberark/servicediscovery/discovery_test.go @@ -66,7 +66,7 @@ func Test_DiscoverIdentityAPIURL(t *testing.T) { client := New(httpClient) - services, err := client.DiscoverServices(ctx, testSpec.subdomain) + services, _, err := client.DiscoverServices(ctx, testSpec.subdomain) if testSpec.expectedError != nil { assert.EqualError(t, err, testSpec.expectedError.Error()) assert.Nil(t, services) diff --git a/pkg/client/client_cyberark.go b/pkg/client/client_cyberark.go index 735313bd..3c573689 100644 --- a/pkg/client/client_cyberark.go +++ b/pkg/client/client_cyberark.go @@ -67,7 +67,7 @@ func (o *CyberArkClient) PostDataReadingsWithOptions(ctx context.Context, readin discoveryClient := servicediscovery.New(o.httpClient) - serviceMap, err := discoveryClient.DiscoverServices(ctx, cfg.Subdomain) + serviceMap, tenantUUID, err := discoveryClient.DiscoverServices(ctx, cfg.Subdomain) if err != nil { return err } @@ -81,7 +81,7 @@ func (o *CyberArkClient) PostDataReadingsWithOptions(ctx context.Context, readin // Minimize the snapshot to reduce size and improve privacy minimizeSnapshot(log.V(logs.Debug), &snapshot) - datauploadClient, err := cyberark.NewDatauploadClient(ctx, o.httpClient, serviceMap, cfg) + datauploadClient, err := cyberark.NewDatauploadClient(ctx, o.httpClient, serviceMap, tenantUUID, cfg) if err != nil { return fmt.Errorf("while initializing data upload client: %s", err) } From 298d84fedb597b2f4bff374f47e6f7fd38b7ea13 Mon Sep 17 00:00:00 2001 From: Ashley Davis Date: Wed, 11 Feb 2026 14:56:50 +0000 Subject: [PATCH 2/2] add file size sending and improve tests for sigv4 Signed-off-by: Ashley Davis --- internal/cyberark/dataupload/dataupload.go | 24 ++- internal/cyberark/dataupload/mock.go | 167 +++++++++++++++--- .../testdata/discovery_success.json.template | 1 + 3 files changed, 156 insertions(+), 36 deletions(-) diff --git a/internal/cyberark/dataupload/dataupload.go b/internal/cyberark/dataupload/dataupload.go index bf4434f0..18fba38e 100644 --- a/internal/cyberark/dataupload/dataupload.go +++ b/internal/cyberark/dataupload/dataupload.go @@ -124,7 +124,7 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err checksumHex := hex.EncodeToString(checksum) checksumBase64 := base64.StdEncoding.EncodeToString(checksum) - presignedUploadURL, username, err := c.retrievePresignedUploadURL(ctx, checksumHex, snapshot.ClusterID) + presignedUploadURL, username, err := c.retrievePresignedUploadURL(ctx, checksumHex, snapshot.ClusterID, int64(encodedBody.Len())) if err != nil { return fmt.Errorf("while retrieving snapshot upload URL: %s", err) } @@ -168,20 +168,30 @@ func (c *CyberArkClient) PutSnapshot(ctx context.Context, snapshot Snapshot) err return nil } -func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksum string, clusterID string) (string, string, error) { +// RetrievePresignedUploadURLRequest is the JSON body sent to the inventory API to request a presigned upload URL. +type RetrievePresignedUploadURLRequest struct { + ClusterID string `json:"cluster_id"` + Checksum string `json:"checksum_sha256"` + + // AgentVersion is the v-prefixed version of the agent uploading the snapshot. + // Note that the backend relies on this version being v-prefixed semver. + AgentVersion string `json:"agent_version"` + + // FileSize is the size of the data we'll upload in bytes + FileSize int64 `json:"file_size"` +} + +func (c *CyberArkClient) retrievePresignedUploadURL(ctx context.Context, checksum string, clusterID string, fileSize int64) (string, string, error) { uploadURL, err := url.JoinPath(c.baseURL, apiPathSnapshotLinks) if err != nil { return "", "", err } - request := struct { - ClusterID string `json:"cluster_id"` - Checksum string `json:"checksum_sha256"` - AgentVersion string `json:"agent_version"` - }{ + request := RetrievePresignedUploadURLRequest{ ClusterID: clusterID, Checksum: checksum, AgentVersion: version.PreflightVersion, + FileSize: fileSize, } encodedBody := &bytes.Buffer{} diff --git a/internal/cyberark/dataupload/mock.go b/internal/cyberark/dataupload/mock.go index 80daf395..1aca88db 100644 --- a/internal/cyberark/dataupload/mock.go +++ b/internal/cyberark/dataupload/mock.go @@ -2,13 +2,17 @@ package dataupload import ( "bytes" + "crypto/rand" "crypto/sha256" "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" + "net/url" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -25,9 +29,19 @@ const ( successClusterID = "ffffffff-ffff-ffff-ffff-ffffffffffff" ) +type uploadValues struct { + ClusterID string + FileSize int64 +} + type mockDataUploadServer struct { t testing.TB serverURL string + + mux *http.ServeMux + + expectedUploadValues map[string]uploadValues + expectedUploadValuesMutex sync.Mutex } // MockDataUploadServer starts a server which mocks the CyberArk @@ -45,13 +59,24 @@ type mockDataUploadServer struct { // responses. func MockDataUploadServer(t testing.TB) (string, *http.Client) { mux := http.NewServeMux() - server := httptest.NewTLSServer(mux) - t.Cleanup(server.Close) mds := &mockDataUploadServer{ - t: t, - serverURL: server.URL, + t: t, + + expectedUploadValues: make(map[string]uploadValues), } - mux.Handle("/", mds) + + mux.HandleFunc("POST "+apiPathSnapshotLinks, mds.handleSnapshotLinks) + + // The path includes random data to ensure that each request is treated separately by the mock server, allowing us to track data across calls. + // It also ensures that the client isn't using some pre-saved path and is actually using the presigned URL returned by the mock server in the previous step, which is important for test validity. + mux.HandleFunc("PUT /presigned-upload/{randData}", mds.handlePresignedUpload) + + server := httptest.NewTLSServer(mds) + t.Cleanup(server.Close) + + mds.mux = mux + mds.serverURL = server.URL + httpClient := server.Client() httpClient.Transport = transport.NewDebuggingRoundTripper(httpClient.Transport, transport.DebugByContext) return server.URL, httpClient @@ -59,25 +84,23 @@ func MockDataUploadServer(t testing.TB) (string, *http.Client) { func (mds *mockDataUploadServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { mds.t.Log(r.Method, r.RequestURI) - switch r.URL.Path { - case apiPathSnapshotLinks: - mds.handleSnapshotLinks(w, r) - return - case "/presigned-upload": - mds.handlePresignedUpload(w, r) - return - default: - w.WriteHeader(http.StatusNotFound) - } + + mds.mux.ServeHTTP(w, r) } -func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - w.WriteHeader(http.StatusMethodNotAllowed) - _, _ = w.Write([]byte(`{"message":"method not allowed"}`)) - return +// randHex reads 8 random bytes and returns them as a hex string. It is used to generate +// unique paths per-request to ensure that file size is tracked across calls. +func randHex() string { + b := make([]byte, 8) + _, err := rand.Read(b) + if err != nil { + panic("failed to read random bytes: " + err.Error()) } + return hex.EncodeToString(b) +} + +func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *http.Request) { if r.Header.Get("User-Agent") != version.UserAgent() { http.Error(w, "should set user agent on all requests", http.StatusInternalServerError) return @@ -99,13 +122,11 @@ func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *h return } + var req RetrievePresignedUploadURLRequest + decoder := json.NewDecoder(r.Body) - var req struct { - ClusterID string `json:"cluster_id"` - Checksum string `json:"checksum_sha256"` - AgentVersion string `json:"agent_version"` - } decoder.DisallowUnknownFields() + if err := decoder.Decode(&req); err != nil { http.Error(w, `{"error": "Invalid request format"}`, http.StatusBadRequest) return @@ -135,10 +156,33 @@ func (mds *mockDataUploadServer) handleSnapshotLinks(w http.ResponseWriter, r *h return } + if req.FileSize <= 0 { + http.Error(w, "file size must be greater than 0", http.StatusInternalServerError) + return + } + + randomData := randHex() + + mds.expectedUploadValuesMutex.Lock() + defer mds.expectedUploadValuesMutex.Unlock() + + uploadValues := uploadValues{ + ClusterID: req.ClusterID, + FileSize: req.FileSize, + } + + mds.expectedUploadValues[randomData] = uploadValues + + presignedURL, err := url.JoinPath(mds.serverURL, "presigned-upload", randomData) + if err != nil { + http.Error(w, "failed to generate presigned URL", http.StatusInternalServerError) + mds.t.Logf("failed to generate presigned URL: %v", err) + return + } + // Write response body w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") - presignedURL := mds.serverURL + "/presigned-upload" _ = json.NewEncoder(w).Encode(struct { URL string `json:"url"` }{presignedURL}) @@ -155,9 +199,18 @@ const amzExampleChecksumError = ` ` func (mds *mockDataUploadServer) handlePresignedUpload(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPut { - w.WriteHeader(http.StatusMethodNotAllowed) - _, _ = w.Write([]byte(`{"message":"method not allowed"}`)) + randData := r.PathValue("randData") + if randData == "" { + http.Error(w, "missing randData in path; should match that returned in presigned url", http.StatusInternalServerError) + return + } + + mds.expectedUploadValuesMutex.Lock() + uploadValues, ok := mds.expectedUploadValues[randData] + mds.expectedUploadValuesMutex.Unlock() + + if !ok { + http.Error(w, "didn't find a prior call to generate presigned URL", http.StatusInternalServerError) return } @@ -178,9 +231,65 @@ func (mds *mockDataUploadServer) handlePresignedUpload(w http.ResponseWriter, r return } + sseHeader := r.Header.Get("X-Amz-Server-Side-Encryption") + if sseHeader != "AES256" { + http.Error(w, "should set x-amz-server-side-encryption header to AES256 on all requests", http.StatusInternalServerError) + return + } + + taggingHeader := r.Header.Get("X-Amz-Tagging") + if taggingHeader == "" { + http.Error(w, "should set x-amz-tagging header on all requests", http.StatusInternalServerError) + return + } + + tags, err := url.ParseQuery(taggingHeader) + if err != nil { + http.Error(w, "x-amz-tagging header should be encoded as a valid query string", http.StatusInternalServerError) + return + } + + if tags.Get("agent_version") != version.PreflightVersion { + http.Error(w, fmt.Sprintf("x-amz-tagging should contain an agent_version tag with value %s", version.PreflightVersion), http.StatusInternalServerError) + return + } + + if tags.Get("tenant_id") == "" { + // TODO: if we change setup a bit, we can check the tenant_id matches the expected tenant_id from the test config, but for now, just check it's set + http.Error(w, "x-amz-tagging should contain a tenant_id tag", http.StatusInternalServerError) + return + } + + if tags.Get("upload_type") != "k8s_snapshot" { + http.Error(w, "x-amz-tagging should contain an upload_type tag with value k8s_snapshot", http.StatusInternalServerError) + return + } + + if tags.Get("uploader_id") != uploadValues.ClusterID { + http.Error(w, "x-amz-tagging should contain an uploader_id tag which matches the cluster ID sent in the RetrievePresignedUploadURL request", http.StatusInternalServerError) + return + } + + if tags.Get("username") == "" { + // TODO: if we change setup a bit, we can check the username matches the expected username from the test config + // but for now, just check it's set + http.Error(w, "x-amz-tagging should contain a username tag", http.StatusInternalServerError) + return + } + + if tags.Get("vendor") != "k8s" { + http.Error(w, "x-amz-tagging should contain a vendor tag with value k8s", http.StatusInternalServerError) + return + } + body, err := io.ReadAll(r.Body) require.NoError(mds.t, err) + if uploadValues.FileSize != int64(len(body)) { + http.Error(w, fmt.Sprintf("file size in request body should match that sent in RetrievePresignedUploadURL request; expected %d, got %d", uploadValues.FileSize, len(body)), http.StatusInternalServerError) + return + } + hash := sha256.New() _, err = hash.Write(body) require.NoError(mds.t, err) diff --git a/internal/cyberark/servicediscovery/testdata/discovery_success.json.template b/internal/cyberark/servicediscovery/testdata/discovery_success.json.template index c503028c..ee04b067 100644 --- a/internal/cyberark/servicediscovery/testdata/discovery_success.json.template +++ b/internal/cyberark/servicediscovery/testdata/discovery_success.json.template @@ -3,6 +3,7 @@ "dr_region": "us-east-2", "subdomain": "venafi-test", "platform_id": "platform-123", + "tenant_id": "tenant-123", "identity_id": "identity-456", "default_url": "https://venafi-test.integration-cyberark.cloud", "tenant_flags": {