diff --git a/.terraform.lock.hcl b/.terraform.lock.hcl index 4b829f52..a462d8d9 100644 --- a/.terraform.lock.hcl +++ b/.terraform.lock.hcl @@ -2,37 +2,37 @@ # Manual edits may be lost in future updates. provider "registry.terraform.io/hashicorp/aws" { - version = "6.30.0" + version = "6.31.0" constraints = ">= 4.56.0" hashes = [ - "h1:61K3makVG+zqd7eePXPsAFpQZN33Z28Kf9g+OLv/JYM=", - "h1:Bao0MYQHdNQzavGviQLUdR3A1u2NOv6OIc2V5I+VyuY=", - "h1:FNkicntiPhllPhKf8uBJTCQVY/cqN/sXa/LwE4Q0ML8=", - "h1:NHCJ8SQ71K+p5YKety3SY4PvZ0MfIv92c8blfVf4QP0=", - "h1:NNwip4EfMdYRY+fLkSQHekZS67EFeKJmr5Lmu80ajlI=", - "h1:XPXgScQHM4sfRlz7jRNek2/2hj2ZyLhZpaFVtLrzT7o=", - "h1:bpwk17AlZ00qP8tsBna6RUh4qngv+P3VQY26eNtDO4s=", - "h1:iT9e39SGzBWyq7gcmNkBZrWA6vMGJoYMS4CCKHIclqA=", - "h1:ilAhYTs7SG2u59KjKmbIwZq+DAcV7s3cirbjJAMX+ZM=", - "h1:kmLcNCYh5eYzvS+RWD0dxf0qk1u1Ix/8fNkVhSXEciE=", - "h1:qApi394T9DCGHdUsB1JMzZD06uUjyDBI4XYAjBafOY4=", - "h1:rDgnw9NMNfSRUQiv0YZEGWDBqk46RvlBNJt4zOIijMY=", - "h1:u3SrPueECINoUjAy9ix3MG0SaXcfUSghQCMOnnk0FCM=", - "h1:weYTFOITWwcJ7d3/FWWElAYhWcDfyUI19WTct4fdOmg=", - "zh:08fdcbb84b63739b758fd2f657303f495859ae15f2d6c3dbd642520cadb5f063", - "zh:1e69ff49906541cd511bdabcd4b2996a731b1642ba26b834cdac5432e8d5c557", - "zh:3aa23e3af1fb1dd0c025cb8fb73abdabd3f44b6a687a2a239947e7b0201b2f1f", - "zh:4b3b81e63eee913c874e8115d6a83d12bd9d7903446f91be15ba50c583c79549", - "zh:6e93a72d8770d73a4122dc82af33a020d58feeaca4e194a2685dce30dbcdce24", - "zh:74be722c9a64b95e06554cde0bef624084cc5a5ea7f3373f1975b7a4737d7074", - "zh:7d2acf6bc93be26504fd0e2965c77699a49549f74a767d0a81430d9e12d51358", + "h1:1s5nzQMTgvB1AR2yUZcLz7hDM4PdEZZkKdEnBKaq8AY=", + "h1:8ylMF1VIXmf5Dc2Ew5BGOx+T1jblHB0Ik6nzxWxuu2Y=", + "h1:KYem6P5vF6BD+OpuZxrD2Vj668i54BS7+x4pgu0U6DY=", + "h1:M7p95xWoE3zWo1hx0aGg7GB262c7OizE+h6ISrzl1So=", + "h1:NPMzGvaqqewWvKA5ButStYnFWbgFIExWNKuuHZug1sE=", + "h1:PWZAzhvwJ7ccFn9c7EPuCrmGZHk+4w1cqeNA0DDGh/I=", + "h1:QSBCfLTHwJ4iMMFXEwGnc0hJ4pVZ94cenZYLXjnLJx8=", + "h1:bpfP3nGNyZEDeITqMfP/RSZ7aNsdskoK1aDj1t3xWQw=", + "h1:cptVA3yG3sJCAFVcTAtdoreGvvNUqskDYWkGSe/toB4=", + "h1:frOihN9Tg4JBTrN5eBhWmMcR4chDM+YJazvwgXQSHFU=", + "h1:mbLgY4xcgqOsFbkWMzOHa478thZM5SX1ci0F7Isf+7k=", + "h1:pQvUqZS7sRQzYj06cmDoRL/AkRgGo16laRl2nDQC0hQ=", + "h1:uOBPLestaUsogblJEzezIuPD3yi2VqAxjMn3usJ7i5g=", + "h1:yrswE/HFLKpPvs2622EyUPWv87ir0OuixxhrESWTt4g=", + "zh:0184b83f61dfb2f90f051d6a10e85d554809eb7dec13c49000bc884cfd1e956d", + "zh:16f76019ad67f0dfafea2c65b17bd1aa289cb5c275521df71337e23b08af6fec", + "zh:296ebaa261729b78159694e3ca709735c5c67913d6107c7e1abd4d1e9b05fc6b", + "zh:6b4c37bd7e8abca1b428903212de731b04695dcc59e2ba2acefc3d936b36c4dc", + "zh:6f49e2f7464dbb9d6911dd32951637f589d21a5c3c9a7c5056837701977ec803", + "zh:6fc4095e59286dd83e9528346390b0b07b3bffa1d46b50027c9e080352207626", + "zh:98816c0c5d1b956b564c2d2feb423fdf4eb3e476a6c5202668a285ff3b2d6910", "zh:9b12af85486a96aedd8d7984b0ff811a4b42e3d88dad1a3fb4c0b580d04fa425", - "zh:aef629bc537b4cc0f64ece87bc2bfdb3e032a4d03a3f7f301f4c84ffdc2ac1ac", - "zh:b41dcc4a2c8e356d82d3f92629aab0e25849db106a43e7adf06d8c6bda7af4c9", - "zh:b4d7a9cf9ad5ac5dd07f4ea1e834b63f14e752f9aca9452cd99570fed16e0c12", - "zh:bcb20f64b9b4599fa746305bcff7eeee3da85029dc467f812f950cf45b519436", - "zh:e45a520b82a1d2d42360db1b93d8e96406a7548948ed528bac5018e1d731c5c6", - "zh:f743e4a0e10dc64669469e6a22e47012f07fb94587f5a1e8cf5431da4e878ae1", - "zh:fe1895af7dcc5815896f892b2593fe71b7f4f364b71d9487d6e8b10ef244c11c", + "zh:a70b34fb8d5a7d3b3823046938f3c9b0527afa93d02086b3d87ffa668c9a350e", + 
"zh:c24c0a58a8301d13cb4c27738840c8e7e0f29563ccf8d6b5ca1c87fcf21bdf89", + "zh:c95d44b2baea56b03198acaaf50f9196504d0207118e1afca7d70b5840315dc4", + "zh:db5a7692e2bde721a37b83f89eb9886715dbd17eb45858b1b58b8f7903ce5144", + "zh:db706b23a652e06c6c3f5de1da70e55a20a4fc2f73c01c24f7b9cd39a9a35f56", + "zh:fb781119fa98d8b0318ffb26ef013d5e674637e8f6a36b4b9c2742f24c022538", + "zh:fc459573b260a5a295d5fed442bf4f44fe034a0702ca22a65d320bd3b3e70eb5", ] } diff --git a/aws-source/adapters/adapterhelpers_always_get_source.go b/aws-source/adapters/adapterhelpers_always_get_source.go index ef69fe26..85c730ee 100644 --- a/aws-source/adapters/adapterhelpers_always_get_source.go +++ b/aws-source/adapters/adapterhelpers_always_get_source.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "sync" + "sync/atomic" "time" "buf.build/go/protovalidate" @@ -80,7 +80,7 @@ type AlwaysGetAdapter[ListInput InputType, ListOutput OutputType, GetInput Input ListFuncOutputMapper func(output ListOutput, input ListInput) ([]GetInput, error) CacheDuration time.Duration // How long to cache items for - cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) + cache sdpcache.Cache // This is mandatory } func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) cacheDuration() time.Duration { @@ -91,21 +91,6 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru return s.CacheDuration } -var ( - noOpCacheAlwaysGetOnce sync.Once - noOpCacheAlwaysGet sdpcache.Cache -) - -func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheAlwaysGetOnce.Do(func() { - noOpCacheAlwaysGet = sdpcache.NewNoOpCache() - }) - return noOpCacheAlwaysGet - } - return s.cache -} - // Validate Checks that the adapter has been set up correctly func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) Validate() error { if !s.DisableList { @@ -168,7 +153,7 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru return nil, WrapAWSError(err) } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -187,12 +172,12 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru if err != nil { err := WrapAWSError(err) if !CanRetry(err) { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) } return nil, err } - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) return item, nil } @@ -218,9 +203,13 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru return } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) defer done() if qErr != nil { + // For better semantics, convert cached NOTFOUND into empty result + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return + } stream.SendError(qErr) return } @@ -238,20 +227,42 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, 
GetInput, GetOutput, ClientStru paginator := s.ListFuncPaginatorBuilder(s.Client, input) var newGetInputs []GetInput p := pool.New().WithContext(ctx).WithMaxGoroutines(s.MaxParallel.Value()) + + // Track whether any items were found and if we had an error + var itemsSent atomic.Int64 + var hadError atomic.Bool + defer func() { // Always wait for everything to be completed before returning err := p.Wait() if err != nil { sentry.CaptureException(err) } + + // Only cache not-found when no items were found AND no error occurred + // If we had an error, that error is already cached, don't overwrite it + shouldCacheNotFound := itemsSent.Load() == 0 && !hadError.Load() + + if shouldCacheNotFound { + notFoundErr := &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: fmt.Sprintf("no %s found in scope %s", s.ItemType, scope), + Scope: scope, + SourceName: s.Name(), + ItemType: s.ItemType, + ResponderName: s.Name(), + } + s.cache.StoreError(ctx, notFoundErr, s.cacheDuration(), ck) + } }() for paginator.HasMorePages() { output, err := paginator.NextPage(ctx) if err != nil { + hadError.Store(true) err := WrapAWSError(err) if !CanRetry(err) { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) } stream.SendError(err) return @@ -259,9 +270,10 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru newGetInputs, err = s.ListFuncOutputMapper(output, input) if err != nil { + hadError.Store(true) err := WrapAWSError(err) if !CanRetry(err) { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) } stream.SendError(err) return @@ -276,10 +288,13 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru if err != nil { // Don't cache individual errors as they are cheap to re-run stream.SendError(WrapAWSError(err)) + // Mark that we had an error so we don't cache NOTFOUND + hadError.Store(true) } if item != nil { - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) stream.SendItem(item) + itemsSent.Add(1) } return nil @@ -322,9 +337,13 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru // SearchCustom Searches using custom mapping logic. 
The SearchInputMapper is // used to create an input for ListFunc, at which point the usual logic is used func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStruct, Options]) SearchCustom(ctx context.Context, scope string, query string, ignoreCache bool, stream discovery.QueryResultStream) { - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) defer done() if qErr != nil { + // For better semantics, convert cached NOTFOUND into empty result + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return + } stream.SendError(qErr) return } @@ -356,15 +375,26 @@ func (s *AlwaysGetAdapter[ListInput, ListOutput, GetInput, GetOutput, ClientStru if err != nil { err := WrapAWSError(err) if !CanRetry(err) { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) } stream.SendError(err) return } if item != nil { - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) stream.SendItem(item) + } else { + // Cache not-found when item is nil + notFoundErr := &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: fmt.Sprintf("%s not found for search query '%s'", s.ItemType, query), + Scope: scope, + SourceName: s.Name(), + ItemType: s.ItemType, + ResponderName: s.Name(), + } + s.cache.StoreError(ctx, notFoundErr, s.cacheDuration(), ck) } } else { stream.SendError(errors.New("SearchCustom called without SearchInputMapper or SearchGetInputMapper")) diff --git a/aws-source/adapters/adapterhelpers_always_get_source_test.go b/aws-source/adapters/adapterhelpers_always_get_source_test.go index 880e82ed..66d1b10b 100644 --- a/aws-source/adapters/adapterhelpers_always_get_source_test.go +++ b/aws-source/adapters/adapterhelpers_always_get_source_test.go @@ -76,6 +76,7 @@ func TestAlwaysGetSourceGet(t *testing.T) { GetInputMapper: func(scope, query string) string { return "" }, + cache: sdpcache.NewNoOpCache(), } _, err := lgs.Get(context.Background(), "foo.bar", "", false) @@ -108,6 +109,7 @@ func TestAlwaysGetSourceGet(t *testing.T) { GetInputMapper: func(scope, query string) string { return "" }, + cache: sdpcache.NewNoOpCache(), } _, err := lgs.Get(context.Background(), "foo.bar", "", false) @@ -144,6 +146,7 @@ func TestAlwaysGetSourceList(t *testing.T) { GetInputMapper: func(scope, query string) string { return "" }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -183,6 +186,7 @@ func TestAlwaysGetSourceList(t *testing.T) { GetInputMapper: func(scope, query string) string { return "" }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -228,6 +232,7 @@ func TestAlwaysGetSourceList(t *testing.T) { GetInputMapper: func(scope, query string) string { return "" }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -275,6 +280,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { GetInputMapper: func(scope, query string) string { return scope + "." + query }, + cache: sdpcache.NewNoOpCache(), } t.Run("bad ARN", func(t *testing.T) { @@ -338,6 +344,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { GetInputMapper: func(scope, query string) string { return scope + "." 
+ query }, + cache: sdpcache.NewNoOpCache(), } t.Run("ARN", func(t *testing.T) { @@ -400,6 +407,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { GetInputMapper: func(scope, query string) string { return "" }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -448,6 +456,7 @@ func TestAlwaysGetSourceSearch(t *testing.T) { GetInputMapper: func(scope, query string) string { return scope + "." + query }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -475,7 +484,7 @@ func TestAlwaysGetSourceCaching(t *testing.T) { Region: "eu-west-2", Client: struct{}{}, ListInput: "", - cache: sdpcache.NewCache(ctx), + cache: sdpcache.NewMemoryCache(), ListFuncPaginatorBuilder: func(client struct{}, input string) Paginator[string, struct{}] { return &TestPaginator{ DataFunc: func() string { diff --git a/aws-source/adapters/adapterhelpers_describe_source.go b/aws-source/adapters/adapterhelpers_describe_source.go index a08186e6..a5826ec6 100644 --- a/aws-source/adapters/adapterhelpers_describe_source.go +++ b/aws-source/adapters/adapterhelpers_describe_source.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "strings" - "sync" "time" "buf.build/go/protovalidate" @@ -29,7 +28,7 @@ type DescribeOnlyAdapter[Input InputType, Output OutputType, ClientStruct Client AdapterMetadata *sdp.AdapterMetadata CacheDuration time.Duration // How long to cache items for - cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) + cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) // The function that should be used to describe the resources that this // adapter is related to @@ -108,21 +107,6 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) cacheDuratio return s.CacheDuration } -var ( - noOpCacheDescribeOnce sync.Once - noOpCacheDescribe sdpcache.Cache -) - -func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheDescribeOnce.Do(func() { - noOpCacheDescribe = sdpcache.NewNoOpCache() - }) - return noOpCacheDescribe - } - return s.cache -} - // Validate Checks that the adapter is correctly set up and returns an error if // not func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) Validate() error { @@ -193,7 +177,7 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) Get(ctx cont return nil, WrapAWSError(err) } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -257,7 +241,7 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) Get(ctx cont ItemType: s.ItemType, ResponderName: s.Name(), } - s.Cache().StoreError(ctx, qErr, s.cacheDuration(), ck) + s.cache.StoreError(ctx, qErr, s.cacheDuration(), ck) return nil, qErr case numItems == 0: @@ -269,11 +253,11 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) Get(ctx cont ItemType: s.ItemType, ResponderName: s.Name(), } - s.Cache().StoreError(ctx, qErr, s.cacheDuration(), ck) + s.cache.StoreError(ctx, qErr, s.cacheDuration(), ck) return nil, qErr } - s.Cache().StoreItem(ctx, items[0], s.cacheDuration(), ck) + s.cache.StoreItem(ctx, items[0], s.cacheDuration(), ck) return items[0], nil } @@ -309,9 +293,13 
@@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) ListStream(c return } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) defer done() if qErr != nil { + // For better semantics, convert cached NOTFOUND into empty result + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return + } stream.SendError(qErr) return } @@ -386,14 +374,20 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) searchARN(ct return } - stream.SendItem(item) + if item != nil { + stream.SendItem(item) + } } // searchCustom Runs custom search logic using the `InputMapperSearch` function func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) searchCustom(ctx context.Context, scope string, query string, ignoreCache bool, stream discovery.QueryResultStream) { - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) defer done() if qErr != nil { + // For better semantics, convert cached NOTFOUND into empty result + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return + } stream.SendError(qErr) return } @@ -424,7 +418,7 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) processError // Only cache the error if is something that won't be fixed by retrying if sdpErr.GetErrorType() == sdp.QueryError_NOTFOUND || sdpErr.GetErrorType() == sdp.QueryError_NOSCOPE { - s.Cache().StoreError(ctx, sdpErr, s.cacheDuration(), cacheKey) + s.cache.StoreError(ctx, sdpErr, s.cacheDuration(), cacheKey) } } @@ -435,6 +429,9 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) processError // run the paginated or unpaginated query. 
This handles caching, error handling, // and post-search filtering if the query param is passed func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) describe(ctx context.Context, query *string, input Input, scope string, ck sdpcache.CacheKey, stream discovery.QueryResultStream) { + // Track whether any items were found + itemsSent := 0 + if s.Paginated() { paginator := s.PaginatorBuilder(s.Client, input) @@ -460,8 +457,9 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) describe(ctx } for _, item := range items { - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) stream.SendItem(item) + itemsSent++ } } } else { @@ -486,9 +484,30 @@ func (s *DescribeOnlyAdapter[Input, Output, ClientStruct, Options]) describe(ctx } for _, item := range items { - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) stream.SendItem(item) + itemsSent++ + } + } + + // Cache not-found when no items were found + if itemsSent == 0 { + var errorString string + if query != nil { + errorString = fmt.Sprintf("no %s found for search query '%s' in scope %s", s.ItemType, *query, scope) + } else { + errorString = fmt.Sprintf("no %s found in scope %s", s.ItemType, scope) + } + + notFoundErr := &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: errorString, + Scope: scope, + SourceName: s.Name(), + ItemType: s.ItemType, + ResponderName: s.Name(), } + s.cache.StoreError(ctx, notFoundErr, s.cacheDuration(), ck) } } diff --git a/aws-source/adapters/adapterhelpers_describe_source_test.go b/aws-source/adapters/adapterhelpers_describe_source_test.go index 75465674..323a500c 100644 --- a/aws-source/adapters/adapterhelpers_describe_source_test.go +++ b/aws-source/adapters/adapterhelpers_describe_source_test.go @@ -81,6 +81,7 @@ func TestGet(t *testing.T) { describeFuncCalled = true return "", nil }, + cache: sdpcache.NewNoOpCache(), } item, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) @@ -154,6 +155,7 @@ func TestGet(t *testing.T) { return "", nil }, UseListForGet: true, + cache: sdpcache.NewNoOpCache(), } item, err := s.Get(context.Background(), "foo.eu-west-2", uniqueAttributeValue, false) @@ -199,6 +201,7 @@ func TestGet(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) @@ -225,6 +228,7 @@ func TestGet(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } _, err := s.Get(context.Background(), "foo.eu-west-2", "bar", false) @@ -254,6 +258,7 @@ func TestSearchARN(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "fancy", nil }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -302,6 +307,7 @@ func TestSearchCustom(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return input, nil }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -354,6 +360,7 @@ func TestNoInputMapper(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } t.Run("Get", func(t *testing.T) { @@ -388,6 
+395,7 @@ func TestNoOutputMapper(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } t.Run("Get", func(t *testing.T) { @@ -424,6 +432,7 @@ func TestNoDescribeFunc(t *testing.T) { {}, }, nil }, + cache: sdpcache.NewNoOpCache(), } t.Run("Get", func(t *testing.T) { @@ -463,6 +472,7 @@ func TestFailingInputMapper(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } fooBar := regexp.MustCompile("foobar") @@ -511,6 +521,7 @@ func TestFailingOutputMapper(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } fooBar := regexp.MustCompile("foobar") @@ -561,6 +572,7 @@ func TestFailingDescribeFunc(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", errors.New("foobar") }, + cache: sdpcache.NewNoOpCache(), } fooBar := regexp.MustCompile("foobar") @@ -638,6 +650,7 @@ func TestPaginated(t *testing.T) { DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } t.Run("detecting pagination", func(t *testing.T) { @@ -675,7 +688,7 @@ func TestDescribeOnlySourceCaching(t *testing.T) { MaxResultsPerPage: 1, Region: "eu-west-2", AccountID: "foo", - cache: sdpcache.NewCache(ctx), + cache: sdpcache.NewMemoryCache(), InputMapperGet: func(scope, query string) (string, error) { return "input", nil }, @@ -837,3 +850,82 @@ func TestDescribeOnlySourceCaching(t *testing.T) { } }) } + +// TestListCachingZeroItems demonstrates that LIST caching works when 0 items are returned. +// This is a simple test to verify that repeated LIST calls don't hit the backend when +// the first call returned no items. 
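+// A minimal usage sketch of the pattern exercised below (assumes the
+// memory-backed sdpcache and recording stream already used in this file):
+//
+//	adapter := &DescribeOnlyAdapter[string, string, struct{}, struct{}]{cache: sdpcache.NewMemoryCache() /* ...other fields as below... */}
+//	stream := discovery.NewRecordingQueryResultStream()
+//	adapter.ListStream(ctx, "123456789012.us-east-1", false, stream) // backend hit, NOTFOUND cached
+//	adapter.ListStream(ctx, "123456789012.us-east-1", false, stream) // cache hit, no backend call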
+func TestListCachingZeroItems(t *testing.T) { + ctx := context.Background() + describeCalls := 0 + cache := sdpcache.NewMemoryCache() + + adapter := &DescribeOnlyAdapter[string, string, struct{}, struct{}]{ + ItemType: "ec2-instance", + Region: "us-east-1", + AccountID: "123456789012", + cache: cache, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "ec2-instance", + DescriptiveName: "EC2 Instance", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get an EC2 instance by ID", + ListDescription: "List all EC2 instances", + }, + }, + InputMapperGet: func(scope, query string) (string, error) { + return query, nil + }, + InputMapperList: func(scope string) (string, error) { + return "", nil + }, + DescribeFunc: func(ctx context.Context, client struct{}, input string) (string, error) { + describeCalls++ + t.Logf("DescribeFunc called (call #%d)", describeCalls) + return "", nil + }, + OutputMapper: func(ctx context.Context, client struct{}, scope, input, output string) ([]*sdp.Item, error) { + // Return empty slice - simulates no EC2 instances found + return []*sdp.Item{}, nil + }, + } + + // First LIST call - should hit the backend + stream1 := discovery.NewRecordingQueryResultStream() + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream1) + + if describeCalls != 1 { + t.Errorf("First call: expected 1 DescribeFunc call, got %d", describeCalls) + } + if len(stream1.GetItems()) != 0 { + t.Errorf("First call: expected 0 items, got %d", len(stream1.GetItems())) + } + t.Logf("First call complete: %d items, %d errors", len(stream1.GetItems()), len(stream1.GetErrors())) + + // Second LIST call - should hit cache, NOT the backend + stream2 := discovery.NewRecordingQueryResultStream() + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream2) + + if describeCalls != 1 { + t.Errorf("Second call: expected still 1 DescribeFunc call (cache hit), got %d", describeCalls) + } + if len(stream2.GetItems()) != 0 { + t.Errorf("Second call: expected 0 items, got %d", len(stream2.GetItems())) + } + // For backward compatibility, cached NOTFOUND is treated as empty result (no error) + // This matches the behavior of the first call which returns empty stream with no errors + if len(stream2.GetErrors()) != 0 { + t.Errorf("Second call: expected 0 errors from cache (backward compatibility), got %d errors", len(stream2.GetErrors())) + } + t.Logf("Second call complete: %d items, %d errors (cache hit!)", len(stream2.GetItems()), len(stream2.GetErrors())) + + // Third LIST call with ignoreCache=true - should bypass cache and hit backend + stream3 := discovery.NewRecordingQueryResultStream() + adapter.ListStream(ctx, "123456789012.us-east-1", true, stream3) // ignoreCache=true + + if describeCalls != 2 { + t.Errorf("Third call (ignoreCache): expected 2 DescribeFunc calls, got %d", describeCalls) + } + t.Logf("Third call (ignoreCache=true) complete: %d items, %d errors", len(stream3.GetItems()), len(stream3.GetErrors())) +} diff --git a/aws-source/adapters/adapterhelpers_get_list_adapter_v2.go b/aws-source/adapters/adapterhelpers_get_list_adapter_v2.go index d3a5f90e..b156e921 100644 --- a/aws-source/adapters/adapterhelpers_get_list_adapter_v2.go +++ b/aws-source/adapters/adapterhelpers_get_list_adapter_v2.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "sync" "time" "github.com/overmindtech/cli/discovery" @@ -24,7 +23,7 @@ type GetListAdapterV2[ListInput InputType, ListOutput OutputType, AWSItem AWSIte AdapterMetadata *sdp.AdapterMetadata 
CacheDuration time.Duration // How long to cache items for - cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) + cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) // Disables List(), meaning all calls will return empty results. This does // not affect Search() @@ -75,21 +74,6 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] return s.CacheDuration } -var ( - noOpCacheV2Once sync.Once - noOpCacheV2 sdpcache.Cache -) - -func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheV2Once.Do(func() { - noOpCacheV2 = sdpcache.NewNoOpCache() - }) - return noOpCacheV2 - } - return s.cache -} - // Validate Checks that the adapter has been set up correctly func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options]) Validate() error { if s.GetFunc == nil { @@ -171,7 +155,7 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] } } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -188,7 +172,7 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] if err != nil { err := WrapAWSError(err) if !CanRetry(err) { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) } return nil, err } @@ -207,7 +191,7 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] } } - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) return item, nil } @@ -235,9 +219,13 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] return } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) defer done() if qErr != nil { + // For better semantics, convert cached NOTFOUND into empty result + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return + } stream.SendError(qErr) return } @@ -254,11 +242,16 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] return } + // Track whether any items were found and if we had an error + itemsSent := 0 + hadError := false + // Define the function to send the outputs sendOutputs := func(out ListOutput) { // Extract the items in the correct format awsItems, err := s.ListExtractor(ctx, out, s.Client) if err != nil { + hadError = true stream.SendError(WrapAWSError(err)) return } @@ -268,6 +261,7 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] for _, awsItem := range awsItems { item, err := s.ItemMapper(nil, scope, awsItem) if err != nil { + hadError = true stream.SendError(WrapAWSError(err)) continue } @@ -280,7 +274,8 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] } stream.SendItem(item) - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + itemsSent++ + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) } } @@ -291,6 +286,7 @@ func (s *GetListAdapterV2[ListInput, 
ListOutput, AWSItem, ClientStruct, Options] for paginator.HasMorePages() { out, err := paginator.NextPage(ctx) if err != nil { + hadError = true stream.SendError(WrapAWSError(err)) return } @@ -300,12 +296,27 @@ func (s *GetListAdapterV2[ListInput, ListOutput, AWSItem, ClientStruct, Options] } else if s.ListFunc != nil { out, err := s.ListFunc(ctx, s.Client, listInput) if err != nil { + hadError = true stream.SendError(WrapAWSError(err)) return } sendOutputs(out) } + + // Cache not-found only when no items were found AND no error occurred + // If we had an error, that error is already sent to the stream, don't overwrite it + if itemsSent == 0 && !hadError { + notFoundErr := &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: fmt.Sprintf("no %s found in scope %s", s.ItemType, scope), + Scope: scope, + SourceName: s.Name(), + ItemType: s.ItemType, + ResponderName: s.Name(), + } + s.cache.StoreError(ctx, notFoundErr, s.cacheDuration(), ck) + } } // Search Searches for AWS resources, this can be implemented either as a diff --git a/aws-source/adapters/adapterhelpers_get_list_adapter_v2_test.go b/aws-source/adapters/adapterhelpers_get_list_adapter_v2_test.go index 9b6dc42c..6ca9a41f 100644 --- a/aws-source/adapters/adapterhelpers_get_list_adapter_v2_test.go +++ b/aws-source/adapters/adapterhelpers_get_list_adapter_v2_test.go @@ -60,6 +60,7 @@ func TestGetListAdapterV2Get(t *testing.T) { "foo": "bar", }, nil }, + cache: sdpcache.NewNoOpCache(), } item, err := s.Get(context.Background(), "12345.eu-west-2", "", false) @@ -83,6 +84,7 @@ func TestGetListAdapterV2Get(t *testing.T) { ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, nil }, + cache: sdpcache.NewNoOpCache(), } if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { @@ -101,6 +103,7 @@ func TestGetListAdapterV2Get(t *testing.T) { ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, errors.New("mapper error") }, + cache: sdpcache.NewNoOpCache(), } if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { @@ -135,6 +138,7 @@ func TestGetListAdapterV2ListStream(t *testing.T) { InputMapperList: func(scope string) (string, error) { return "input", nil }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -165,6 +169,7 @@ func TestGetListAdapterV2ListStream(t *testing.T) { ItemMapper: func(query *string, scope, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, nil }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -196,6 +201,7 @@ func TestGetListAdapterV2ListStream(t *testing.T) { InputMapperList: func(scope string) (string, error) { return "input", nil }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -268,6 +274,7 @@ func TestListFuncPaginatorBuilder(t *testing.T) { GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { return "", nil }, + cache: sdpcache.NewNoOpCache(), } stream := discovery.NewRecordingQueryResultStream() @@ -291,7 +298,7 @@ func TestGetListAdapterV2Caching(t *testing.T) { ItemType: "test-type", Region: "eu-west-2", AccountID: "foo", - cache: sdpcache.NewCache(ctx), + cache: sdpcache.NewCache(ctx), GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { generation += 1 return fmt.Sprintf("%v", generation), nil @@ -400,3 +407,84 @@ func 
TestGetListAdapterV2Caching(t *testing.T) { } }) } + +// TestGetListAdapterV2_ListExtractorErrorNoNotFoundCache tests that when ListExtractor fails, +// we don't incorrectly cache NOTFOUND. The error should be sent, but NOTFOUND should not be cached +// because the failure was due to extraction errors, not because items don't exist. +func TestGetListAdapterV2_ListExtractorErrorNoNotFoundCache(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + listCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapterV2[*MockInput, *MockOutput, *MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + return nil, errors.New("should not be called in LIST test") + }, + InputMapperList: func(scope string) (*MockInput, error) { + return &MockInput{}, nil + }, + ListFunc: func(ctx context.Context, client *MockClient, input *MockInput) (*MockOutput, error) { + listCalls++ + // Return a valid output that indicates items exist + return &MockOutput{}, nil + }, + ListExtractor: func(ctx context.Context, output *MockOutput, client *MockClient) ([]*MockAWSItem, error) { + // Simulate extraction failure - this should NOT result in NOTFOUND caching + return nil, errors.New("extraction failed") + }, + ItemMapper: func(query *string, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call - ListExtractor fails, should send error but NOT cache NOTFOUND + stream1 := discovery.NewRecordingQueryResultStream() + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream1) + + if len(stream1.GetItems()) != 0 { + t.Errorf("Expected 0 items, got %d", len(stream1.GetItems())) + } + if len(stream1.GetErrors()) != 1 { + t.Errorf("Expected 1 error from ListExtractor failure, got %d", len(stream1.GetErrors())) + } + if listCalls != 1 { + t.Errorf("Expected 1 ListFunc call, got %d", listCalls) + } + + // Second call - should NOT hit cache (NOTFOUND was not cached), should try again + stream2 := discovery.NewRecordingQueryResultStream() + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream2) + + if listCalls != 2 { + t.Errorf("Expected 2 ListFunc calls (no cache hit because NOTFOUND was not cached), got %d", listCalls) + } + if len(stream2.GetItems()) != 0 { + t.Errorf("Expected 0 items, got %d", len(stream2.GetItems())) + } + if len(stream2.GetErrors()) != 1 { + t.Errorf("Expected 1 error from ListExtractor failure, got %d", len(stream2.GetErrors())) + } +} diff --git a/aws-source/adapters/adapterhelpers_get_list_source.go b/aws-source/adapters/adapterhelpers_get_list_source.go index c1e6d1b0..abeb23d4 100644 --- a/aws-source/adapters/adapterhelpers_get_list_source.go +++ b/aws-source/adapters/adapterhelpers_get_list_source.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "sync" "time" "buf.build/go/protovalidate" @@ -23,7 +22,7 @@ type GetListAdapter[AWSItem AWSItemType, ClientStruct ClientStructType, Options AdapterMetadata 
*sdp.AdapterMetadata CacheDuration time.Duration // How long to cache items for - cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) + cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) // Disables List(), meaning all calls will return empty results. This does // not affect Search() @@ -56,21 +55,6 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) cacheDuration() time.Du return s.CacheDuration } -var ( - noOpCacheOnce sync.Once - noOpCache sdpcache.Cache -) - -func (s *GetListAdapter[AWSItem, ClientStruct, Options]) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheOnce.Do(func() { - noOpCache = sdpcache.NewNoOpCache() - }) - return noOpCache - } - return s.cache -} - // Validate Checks that the adapter has been set up correctly func (s *GetListAdapter[AWSItem, ClientStruct, Options]) Validate() error { if s.GetFunc == nil { @@ -144,7 +128,7 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) Get(ctx context.Context } } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.ItemType, query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -161,7 +145,7 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) Get(ctx context.Context if err != nil { err := WrapAWSError(err) if !CanRetry(err) { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) } return nil, err } @@ -180,7 +164,7 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) Get(ctx context.Context } } - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) return item, nil } @@ -199,9 +183,13 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) List(ctx context.Contex return []*sdp.Item{}, nil } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.ItemType, "", ignoreCache) defer done() if qErr != nil { + // For better semantics, convert cached NOTFOUND into empty result + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return []*sdp.Item{}, nil + } return nil, qErr } if cacheHit { @@ -212,15 +200,17 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) List(ctx context.Contex if err != nil { err := WrapAWSError(err) if !CanRetry(err) { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) } return nil, err } items := make([]*sdp.Item, 0) + hadError := false for _, awsItem := range awsItems { item, err := s.ItemMapper("", scope, awsItem) if err != nil { + hadError = true continue } @@ -232,7 +222,20 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) List(ctx context.Contex } items = append(items, item) - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) + } + + // Cache not-found only when no items were found AND no error occurred + if len(items) == 0 && !hadError { + notFoundErr := &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: fmt.Sprintf("no %s found in scope %s", s.ItemType, scope), + Scope: scope, + SourceName: s.Name(), + ItemType: s.ItemType, + ResponderName: 
s.Name(), + } + s.cache.StoreError(ctx, notFoundErr, s.cacheDuration(), ck) } return items, nil @@ -296,15 +299,22 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) SearchARN(ctx context.C return nil, WrapAWSError(err) } - return []*sdp.Item{item}, nil + if item != nil { + return []*sdp.Item{item}, nil + } + return []*sdp.Item{}, nil } // Custom search function that can be used to search for items in a different, // adapter-specific way func (s *GetListAdapter[AWSItem, ClientStruct, Options]) SearchCustom(ctx context.Context, scope string, query string, ignoreCache bool) ([]*sdp.Item, error) { - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.ItemType, query, ignoreCache) defer done() if qErr != nil { + // For better semantics, convert cached NOTFOUND into empty result + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return []*sdp.Item{}, nil + } return nil, qErr } if cacheHit { @@ -314,21 +324,36 @@ func (s *GetListAdapter[AWSItem, ClientStruct, Options]) SearchCustom(ctx contex awsItems, err := s.SearchFunc(ctx, s.Client, scope, query) if err != nil { err = WrapAWSError(err) - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) return nil, err } items := make([]*sdp.Item, 0) + hadError := false var item *sdp.Item for _, awsItem := range awsItems { item, err = s.ItemMapper(query, scope, awsItem) if err != nil { + hadError = true continue } items = append(items, item) - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) + } + + // Cache not-found only when no items were found AND no error occurred + if len(items) == 0 && !hadError { + notFoundErr := &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: fmt.Sprintf("no %s found for search query '%s' in scope %s", s.ItemType, query, scope), + Scope: scope, + SourceName: s.Name(), + ItemType: s.ItemType, + ResponderName: s.Name(), + } + s.cache.StoreError(ctx, notFoundErr, s.cacheDuration(), ck) } return items, nil diff --git a/aws-source/adapters/adapterhelpers_get_list_source_test.go b/aws-source/adapters/adapterhelpers_get_list_source_test.go index b8eddc40..b0601b84 100644 --- a/aws-source/adapters/adapterhelpers_get_list_source_test.go +++ b/aws-source/adapters/adapterhelpers_get_list_source_test.go @@ -62,6 +62,7 @@ func TestGetListSourceGet(t *testing.T) { "foo": "bar", }, nil }, + cache: sdpcache.NewNoOpCache(), } item, err := s.Get(context.Background(), "12345.eu-west-2", "", false) @@ -88,6 +89,7 @@ func TestGetListSourceGet(t *testing.T) { ItemMapper: func(query, scope string, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, nil }, + cache: sdpcache.NewNoOpCache(), } if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { @@ -109,6 +111,7 @@ func TestGetListSourceGet(t *testing.T) { ItemMapper: func(query, scope string, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, errors.New("mapper error") }, + cache: sdpcache.NewNoOpCache(), } if _, err := s.Get(context.Background(), "12345.eu-west-2", "", false); err == nil { @@ -137,6 +140,7 @@ func TestGetListSourceList(t *testing.T) { "foo": "bar", }, nil }, + cache: sdpcache.NewNoOpCache(), } if items, err := s.List(context.Background(), "12345.eu-west-2", false); err != nil { @@ -166,6 +170,7 @@ func 
TestGetListSourceList(t *testing.T) { ItemMapper: func(query, scope string, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, nil }, + cache: sdpcache.NewNoOpCache(), } if _, err := s.List(context.Background(), "12345.eu-west-2", false); err == nil { @@ -187,6 +192,7 @@ func TestGetListSourceList(t *testing.T) { ItemMapper: func(query, scope string, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, errors.New("mapper error") }, + cache: sdpcache.NewNoOpCache(), } if items, err := s.List(context.Background(), "12345.eu-west-2", false); err != nil { @@ -214,6 +220,7 @@ func TestGetListSourceSearch(t *testing.T) { ItemMapper: func(query, scope string, awsItem string) (*sdp.Item, error) { return &sdp.Item{}, nil }, + cache: sdpcache.NewNoOpCache(), } t.Run("bad ARN", func(t *testing.T) { @@ -248,7 +255,7 @@ func TestGetListSourceCaching(t *testing.T) { ItemType: "test-type", Region: "eu-west-2", AccountID: "foo", - cache: sdpcache.NewMemoryCache(), + cache: sdpcache.NewMemoryCache(), GetFunc: func(ctx context.Context, client struct{}, scope, query string) (string, error) { generation += 1 return fmt.Sprintf("%v", generation), nil diff --git a/aws-source/adapters/adapterhelpers_notfound_cache_test.go b/aws-source/adapters/adapterhelpers_notfound_cache_test.go new file mode 100644 index 00000000..3b17fc1d --- /dev/null +++ b/aws-source/adapters/adapterhelpers_notfound_cache_test.go @@ -0,0 +1,804 @@ +package adapters + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" +) + +// TestGetListAdapterV2_GetNotFoundCaching tests that GetListAdapterV2 caches not-found error results +func TestGetListAdapterV2_GetNotFoundCaching(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + getCalls := 0 + + // Mock AWS item type + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapterV2[*MockInput, *MockOutput, *MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + getCalls++ + // Return NOTFOUND error (typical AWS behavior) + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: "resource not found", + Scope: scope, + } + }, + ItemMapper: func(query *string, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call should invoke GetFunc and get error + item, err := adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if item != nil { + t.Errorf("Expected nil item, got %v", item) + } + // First call returns the error (but it's cached) + if err == nil { + t.Error("Expected NOTFOUND error, got nil") + } + if getCalls != 1 { + t.Errorf("Expected 1 GetFunc call, got %d", getCalls) + } + + // Second call should hit cache and return the cached NOTFOUND error + item, err = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if item != nil { + t.Errorf("Expected nil item on cache hit, got %v", item) + } + var qErr 
*sdp.QueryError + if err == nil { + t.Error("Expected NOTFOUND error on cache hit, got nil") + } else if !errors.As(err, &qErr) || qErr.GetErrorType() != sdp.QueryError_NOTFOUND { + t.Errorf("Expected NOTFOUND error on cache hit, got %v", err) + } + if getCalls != 1 { + t.Errorf("Expected still 1 GetFunc call (cache hit), got %d", getCalls) + } +} + +// TestGetListAdapterV2_ListNotFoundCaching tests that GetListAdapterV2 caches not-found results when LIST returns 0 items +func TestGetListAdapterV2_ListNotFoundCaching(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + listCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapterV2[*MockInput, *MockOutput, *MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + return nil, errors.New("should not be called in LIST test") + }, + InputMapperList: func(scope string) (*MockInput, error) { + return &MockInput{}, nil + }, + ListFunc: func(ctx context.Context, client *MockClient, input *MockInput) (*MockOutput, error) { + listCalls++ + return &MockOutput{}, nil + }, + ListExtractor: func(ctx context.Context, output *MockOutput, client *MockClient) ([]*MockAWSItem, error) { + // Return empty slice to simulate no items found + return []*MockAWSItem{}, nil + }, + ItemMapper: func(query *string, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // Use test stream to collect results + stream := &testQueryResultStream{} + + // First call should invoke ListFunc + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream) + if len(stream.items) != 0 { + t.Errorf("Expected 0 items, got %d", len(stream.items)) + } + if listCalls != 1 { + t.Errorf("Expected 1 ListFunc call, got %d", listCalls) + } + + // Second call should hit cache + stream2 := &testQueryResultStream{} + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream2) + if len(stream2.items) != 0 { + t.Errorf("Expected 0 items on cache hit, got %d", len(stream2.items)) + } + // For backward compatibility, cached NOTFOUND is treated as empty result (no error) + // This matches the behavior of the first call which returns empty stream with no errors + if len(stream2.errors) != 0 { + t.Errorf("Expected 0 errors from cache (backward compatibility), got %d", len(stream2.errors)) + } + if listCalls != 1 { + t.Errorf("Expected still 1 ListFunc call (cache hit), got %d", listCalls) + } +} + +// TestGetListAdapter_GetNotFoundCaching tests GetListAdapter's GET not-found caching +func TestGetListAdapter_GetNotFoundCaching(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + getCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapter[*MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) 
{ + getCalls++ + // Return NOTFOUND error (typical AWS behavior) + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: "resource not found", + Scope: scope, + } + }, + ItemMapper: func(query, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call returns error (which gets cached) + item, err := adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if item != nil { + t.Errorf("Expected nil item, got %v", item) + } + if err == nil { + t.Error("Expected NOTFOUND error, got nil") + } + if getCalls != 1 { + t.Errorf("Expected 1 GetFunc call, got %d", getCalls) + } + + // Second call should hit cache and return the cached NOTFOUND error + item, err = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if item != nil { + t.Errorf("Expected nil item on cache hit, got %v", item) + } + var qErr *sdp.QueryError + if err == nil { + t.Error("Expected NOTFOUND error on cache hit, got nil") + } else if !errors.As(err, &qErr) || qErr.GetErrorType() != sdp.QueryError_NOTFOUND { + t.Errorf("Expected NOTFOUND error on cache hit, got %v", err) + } + if getCalls != 1 { + t.Errorf("Expected still 1 GetFunc call (cache hit), got %d", getCalls) + } +} + +// TestGetListAdapter_ListNotFoundCaching tests GetListAdapter's LIST not-found caching +func TestGetListAdapter_ListNotFoundCaching(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + listCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapter[*MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + return nil, errors.New("should not be called") + }, + ListFunc: func(ctx context.Context, client *MockClient, scope string) ([]*MockAWSItem, error) { + listCalls++ + return []*MockAWSItem{}, nil // Empty list + }, + ItemMapper: func(query, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call + items, err := adapter.List(ctx, "123456789012.us-east-1", false) + if len(items) != 0 { + t.Errorf("Expected 0 items, got %d", len(items)) + } + if err != nil { + t.Errorf("Expected nil error, got %v", err) + } + if listCalls != 1 { + t.Errorf("Expected 1 ListFunc call, got %d", listCalls) + } + + // Second call should hit cache and return empty result with nil error (backward compatibility) + items2, err := adapter.List(ctx, "123456789012.us-east-1", false) + // Should get empty result with nil error for backward compatibility + if len(items2) != 0 { + t.Errorf("Expected 0 items from cache, got %d", len(items2)) + } + if err != 
nil { + t.Errorf("Expected nil error from cache (backward compat), got %v", err) + } + if listCalls != 1 { + t.Errorf("Expected still 1 ListFunc call (cache hit), got %d", listCalls) + } +} + +// TestAlwaysGetAdapter_GetNotFoundCaching tests AlwaysGetAdapter's GET not-found caching +func TestAlwaysGetAdapter_GetNotFoundCaching(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + getFuncCalls := 0 + + adapter := &AlwaysGetAdapter[*MockInput, *MockOutput, *MockGetInput, *MockGetOutput, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetInputMapper: func(scope, query string) *MockGetInput { + return &MockGetInput{} + }, + GetFunc: func(ctx context.Context, client *MockClient, scope string, input *MockGetInput) (*sdp.Item, error) { + getFuncCalls++ + // Return NOTFOUND error (typical AWS behavior) + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: "resource not found", + Scope: scope, + } + }, + // Add ListFuncPaginatorBuilder to avoid validation error + ListFuncPaginatorBuilder: func(client *MockClient, input *MockInput) Paginator[*MockOutput, *MockOptions] { + return nil // Not used in GET test + }, + ListFuncOutputMapper: func(output *MockOutput, input *MockInput) ([]*MockGetInput, error) { + return nil, nil // Not used in GET test + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call returns error (which gets cached) + item, err := adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if item != nil { + t.Errorf("Expected nil item, got %v", item) + } + if err == nil { + t.Error("Expected NOTFOUND error, got nil") + } + if getFuncCalls != 1 { + t.Errorf("Expected 1 GetFunc call, got %d", getFuncCalls) + } + + // Second call should hit cache and return the cached NOTFOUND error + item, err = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if item != nil { + t.Errorf("Expected nil item on cache hit, got %v", item) + } + var qErr *sdp.QueryError + if err == nil { + t.Error("Expected NOTFOUND error on cache hit, got nil") + } else if !errors.As(err, &qErr) || qErr.GetErrorType() != sdp.QueryError_NOTFOUND { + t.Errorf("Expected NOTFOUND error on cache hit, got %v", err) + } + if getFuncCalls != 1 { + t.Errorf("Expected still 1 GetFunc call (cache hit), got %d", getFuncCalls) + } +} + +// TestDescribeOnlyAdapter_ListNotFoundCaching tests DescribeOnlyAdapter's LIST not-found caching +func TestDescribeOnlyAdapter_ListNotFoundCaching(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + describeCalls := 0 + + adapter := &DescribeOnlyAdapter[*MockInput, *MockOutput, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + MaxResultsPerPage: 100, // Set to avoid validation using default + DescribeFunc: func(ctx context.Context, client *MockClient, input *MockInput) (*MockOutput, error) { + describeCalls++ + return &MockOutput{}, nil + }, + InputMapperGet: func(scope, query string) (*MockInput, error) { + return &MockInput{}, nil + }, + InputMapperList: func(scope string) (*MockInput, error) { + return &MockInput{}, nil + }, + OutputMapper: func(ctx context.Context, client *MockClient, scope string, 
input *MockInput, output *MockOutput) ([]*sdp.Item, error) { + // Return empty slice to simulate no items found + return []*sdp.Item{}, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + stream := &testQueryResultStream{} + + // First call + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream) + if len(stream.items) != 0 { + t.Errorf("Expected 0 items, got %d", len(stream.items)) + } + if describeCalls != 1 { + t.Errorf("Expected 1 DescribeFunc call, got %d", describeCalls) + } + + // Second call should hit cache + stream2 := &testQueryResultStream{} + adapter.ListStream(ctx, "123456789012.us-east-1", false, stream2) + if len(stream2.items) != 0 { + t.Errorf("Expected 0 items on cache hit, got %d", len(stream2.items)) + } + // For backward compatibility, cached NOTFOUND is treated as empty result (no error) + // This matches the behavior of the first call which returns empty stream with no errors + if len(stream2.errors) != 0 { + t.Errorf("Expected 0 errors from cache (backward compatibility), got %d", len(stream2.errors)) + } + if describeCalls != 1 { + t.Errorf("Expected still 1 DescribeFunc call (cache hit), got %d", describeCalls) + } +} + +// Mock types for testing +type MockClient struct{} +type MockInput struct{} +type MockOutput struct{} +type MockGetInput struct{} +type MockGetOutput struct{} +type MockOptions struct{} + +// testQueryResultStream is a simple implementation of QueryResultStream for testing +type testQueryResultStream struct { + items []*sdp.Item + errors []*sdp.QueryError +} + +func (s *testQueryResultStream) SendItem(item *sdp.Item) { + s.items = append(s.items, item) +} + +func (s *testQueryResultStream) SendError(err error) { + var qErr *sdp.QueryError + if errors.As(err, &qErr) { + s.errors = append(s.errors, qErr) + } else { + s.errors = append(s.errors, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: err.Error(), + }) + } +} + +// TestNotFoundCacheExpiry tests that not-found cache entries expire correctly +func TestNotFoundCacheExpiry(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + getFuncCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapterV2[*MockInput, *MockOutput, *MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + CacheDuration: 100 * time.Millisecond, // Short duration for testing + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + getFuncCalls++ + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: "not found", + } + }, + ItemMapper: func(query *string, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call - should cache not-found + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if getFuncCalls != 1 { 
+ t.Errorf("Expected 1 GetFunc call, got %d", getFuncCalls) + } + + // Immediate second call - should hit cache + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if getFuncCalls != 1 { + t.Errorf("Expected still 1 GetFunc call (cache hit), got %d", getFuncCalls) + } + + // Wait for cache to expire + time.Sleep(150 * time.Millisecond) + + // Third call after expiry - should invoke GetFunc again + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if getFuncCalls != 2 { + t.Errorf("Expected 2 GetFunc calls (cache expired), got %d", getFuncCalls) + } +} + +// TestNotFoundCacheIgnoreCache tests that ignoreCache parameter bypasses not-found cache +func TestNotFoundCacheIgnoreCache(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + getFuncCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapterV2[*MockInput, *MockOutput, *MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + getFuncCalls++ + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: "not found", + } + }, + ItemMapper: func(query *string, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call with ignoreCache=false + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if getFuncCalls != 1 { + t.Errorf("Expected 1 GetFunc call, got %d", getFuncCalls) + } + + // Second call with ignoreCache=true - should bypass cache + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "test-query", true) + if getFuncCalls != 2 { + t.Errorf("Expected 2 GetFunc calls (ignore cache), got %d", getFuncCalls) + } + + // Third call with ignoreCache=false - should still hit cache from first call + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "test-query", false) + if getFuncCalls != 2 { + t.Errorf("Expected still 2 GetFunc calls (cache hit), got %d", getFuncCalls) + } +} + +// TestNotFoundCacheDifferentQueries tests that different queries get separate cache entries +func TestNotFoundCacheDifferentQueries(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + getFuncCalls := 0 + queriesReceived := make(map[string]int) + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapterV2[*MockInput, *MockOutput, *MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + getFuncCalls++ + queriesReceived[query]++ + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: "not found", + } + }, + ItemMapper: func(query *string, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + return &sdp.Item{ + Type: "test-item", + UniqueAttribute: "name", + Attributes: &sdp.ItemAttributes{}, + Scope: scope, + }, nil + }, + AdapterMetadata: 
&sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // Query for item1 + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "item1", false) + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "item1", false) // Cache hit + + // Query for item2 + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "item2", false) + _, _ = adapter.Get(ctx, "123456789012.us-east-1", "item2", false) // Cache hit + + // Should have called GetFunc once per unique query + if getFuncCalls != 2 { + t.Errorf("Expected 2 GetFunc calls (1 per unique query), got %d", getFuncCalls) + } + + if queriesReceived["item1"] != 1 { + t.Errorf("Expected 1 call for item1, got %d", queriesReceived["item1"]) + } + + if queriesReceived["item2"] != 1 { + t.Errorf("Expected 1 call for item2, got %d", queriesReceived["item2"]) + } +} + +// TestGetListAdapter_ListItemMapperErrorNoNotFoundCache tests that when ListFunc returns items +// but ItemMapper fails for all of them, we don't incorrectly cache NOTFOUND. Items actually exist +// but couldn't be mapped, so NOTFOUND should not be cached. +func TestGetListAdapter_ListItemMapperErrorNoNotFoundCache(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + listCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapter[*MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + return nil, errors.New("should not be called in LIST test") + }, + ListFunc: func(ctx context.Context, client *MockClient, scope string) ([]*MockAWSItem, error) { + listCalls++ + // Return items that exist + return []*MockAWSItem{ + {Name: "item1"}, + {Name: "item2"}, + }, nil + }, + ItemMapper: func(query, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + // Simulate mapping failure for all items - this should NOT result in NOTFOUND caching + return nil, errors.New("mapping failed") + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call - ItemMapper fails for all items, should NOT cache NOTFOUND + items1, err1 := adapter.List(ctx, "123456789012.us-east-1", false) + + if len(items1) != 0 { + t.Errorf("Expected 0 items (all mapping failed), got %d", len(items1)) + } + if err1 != nil { + t.Errorf("Expected nil error (errors are silently ignored via continue), got %v", err1) + } + if listCalls != 1 { + t.Errorf("Expected 1 ListFunc call, got %d", listCalls) + } + + // Second call - should NOT hit cache (NOTFOUND was not cached), should try again + items2, err2 := adapter.List(ctx, "123456789012.us-east-1", false) + + if listCalls != 2 { + t.Errorf("Expected 2 ListFunc calls (no cache hit because NOTFOUND was not cached), got %d", listCalls) + } + if len(items2) != 0 { + t.Errorf("Expected 0 items, got %d", len(items2)) + } + if err2 != nil { + t.Errorf("Expected nil error, got %v", err2) + } +} + +// TestGetListAdapter_SearchCustomItemMapperErrorNoNotFoundCache tests that when SearchFunc returns items +// but ItemMapper fails for 
all of them, we don't incorrectly cache NOTFOUND. Items actually exist +// but couldn't be mapped, so NOTFOUND should not be cached. +func TestGetListAdapter_SearchCustomItemMapperErrorNoNotFoundCache(t *testing.T) { + ctx := context.Background() + cache := sdpcache.NewMemoryCache() + searchCalls := 0 + + type MockAWSItem struct { + Name string + } + + adapter := &GetListAdapter[*MockAWSItem, *MockClient, *MockOptions]{ + ItemType: "test-item", + cache: cache, + AccountID: "123456789012", + Region: "us-east-1", + GetFunc: func(ctx context.Context, client *MockClient, scope string, query string) (*MockAWSItem, error) { + return nil, errors.New("should not be called in SEARCH test") + }, + ListFunc: func(ctx context.Context, client *MockClient, scope string) ([]*MockAWSItem, error) { + return nil, errors.New("should not be called in SEARCH test") + }, + SearchFunc: func(ctx context.Context, client *MockClient, scope string, query string) ([]*MockAWSItem, error) { + searchCalls++ + // Return items that exist + return []*MockAWSItem{ + {Name: "item1"}, + {Name: "item2"}, + }, nil + }, + ItemMapper: func(query, scope string, awsItem *MockAWSItem) (*sdp.Item, error) { + // Simulate mapping failure for all items - this should NOT result in NOTFOUND caching + return nil, errors.New("mapping failed") + }, + AdapterMetadata: &sdp.AdapterMetadata{ + Type: "test-item", + DescriptiveName: "Test Item", + SupportedQueryMethods: &sdp.AdapterSupportedQueryMethods{ + Get: true, + List: true, + GetDescription: "Get a test item", + ListDescription: "List all test items", + }, + }, + } + + // First call - ItemMapper fails for all items, should NOT cache NOTFOUND + items1, err1 := adapter.SearchCustom(ctx, "123456789012.us-east-1", "test-query", false) + + if len(items1) != 0 { + t.Errorf("Expected 0 items (all mapping failed), got %d", len(items1)) + } + if err1 != nil { + t.Errorf("Expected nil error (errors are silently ignored via continue), got %v", err1) + } + if searchCalls != 1 { + t.Errorf("Expected 1 SearchFunc call, got %d", searchCalls) + } + + // Second call - should NOT hit cache (NOTFOUND was not cached), should try again + items2, err2 := adapter.SearchCustom(ctx, "123456789012.us-east-1", "test-query", false) + + if searchCalls != 2 { + t.Errorf("Expected 2 SearchFunc calls (no cache hit because NOTFOUND was not cached), got %d", searchCalls) + } + if len(items2) != 0 { + t.Errorf("Expected 0 items, got %d", len(items2)) + } + if err2 != nil { + t.Errorf("Expected nil error, got %v", err2) + } +} diff --git a/aws-source/adapters/adapterhelpers_util.go b/aws-source/adapters/adapterhelpers_util.go index 207af2ba..82c60cae 100644 --- a/aws-source/adapters/adapterhelpers_util.go +++ b/aws-source/adapters/adapterhelpers_util.go @@ -176,23 +176,43 @@ func ParseARN(arnString string) (*ARN, error) { }, nil } +// awsPartitionDNSSuffixes maps AWS partition names to their DNS suffixes. +// This is the single source of truth for all AWS partition DNS suffixes. +// See: https://docs.aws.amazon.com/general/latest/gr/rande.html +var awsPartitionDNSSuffixes = map[string]string{ + "aws": "amazonaws.com", + "aws-us-gov": "amazonaws.com", + "aws-cn": "amazonaws.com.cn", + "aws-iso": "c2s.ic.gov", + "aws-iso-b": "sc2s.sgov.gov", + "aws-eu": "amazonaws.eu", +} + // GetPartitionDNSSuffix returns the DNS suffix for a given AWS partition. // This is used to construct service URLs that work across all AWS partitions. 
func GetPartitionDNSSuffix(partition string) string { - switch partition { - case "aws-cn": - return "amazonaws.com.cn" - case "aws-iso": - return "c2s.ic.gov" - case "aws-iso-b": - return "sc2s.sgov.gov" - case "aws-eu": - return "amazonaws.eu" - case "aws", "aws-us-gov": - return "amazonaws.com" - default: - return "amazonaws.com" // Default to commercial partition + if suffix, ok := awsPartitionDNSSuffixes[partition]; ok { + return suffix } + return "amazonaws.com" // Default to commercial partition +} + +// GetAllAWSPartitionDNSSuffixes returns all known AWS partition DNS suffixes. +// This is useful for checking if a string (like a service principal) belongs +// to any AWS partition. +func GetAllAWSPartitionDNSSuffixes() []string { + // Use a map to deduplicate (aws and aws-us-gov share the same suffix) + seen := make(map[string]bool) + suffixes := make([]string, 0, len(awsPartitionDNSSuffixes)) + + for _, suffix := range awsPartitionDNSSuffixes { + if !seen[suffix] { + seen[suffix] = true + suffixes = append(suffixes, suffix) + } + } + + return suffixes } // WrapAWSError Wraps an AWS error in the appropriate SDP error @@ -337,6 +357,7 @@ func (e E2ETest) Run(t *testing.T) { if streamingAdapter, ok := e.Adapter.(discovery.ListStreamableAdapter); ok { stream := discovery.NewRecordingQueryResultStream() + streamingAdapter.ListStream(context.Background(), scope, false, stream) items = stream.GetItems() errs = stream.GetErrors() diff --git a/aws-source/adapters/apigateway-api-key_test.go b/aws-source/adapters/apigateway-api-key_test.go index 8bf0c41d..21e4243b 100644 --- a/aws-source/adapters/apigateway-api-key_test.go +++ b/aws-source/adapters/apigateway-api-key_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestApiKeyOutputMapper(t *testing.T) { @@ -47,7 +48,7 @@ func TestNewAPIGatewayApiKeyAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayApiKeyAdapter(client, account, region, nil) + adapter := NewAPIGatewayApiKeyAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-authorizer_test.go b/aws-source/adapters/apigateway-authorizer_test.go index aa6c9a4d..e7ff174e 100644 --- a/aws-source/adapters/apigateway-authorizer_test.go +++ b/aws-source/adapters/apigateway-authorizer_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestAuthorizerOutputMapper(t *testing.T) { @@ -50,7 +51,7 @@ func TestNewAPIGatewayAuthorizerAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayAuthorizerAdapter(client, account, region, nil) + adapter := NewAPIGatewayAuthorizerAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-deployment_test.go b/aws-source/adapters/apigateway-deployment_test.go index 70e7bc43..98ae4f60 100644 --- a/aws-source/adapters/apigateway-deployment_test.go +++ b/aws-source/adapters/apigateway-deployment_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + 
"github.com/overmindtech/cli/sdpcache" ) func TestDeploymentOutputMapper(t *testing.T) { @@ -50,7 +51,7 @@ func TestNewAPIGatewayDeploymentAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayDeploymentAdapter(client, account, region, nil) + adapter := NewAPIGatewayDeploymentAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-domain-name.go b/aws-source/adapters/apigateway-domain-name.go index 7e7f733d..9b767355 100644 --- a/aws-source/adapters/apigateway-domain-name.go +++ b/aws-source/adapters/apigateway-domain-name.go @@ -187,7 +187,7 @@ func NewAPIGatewayDomainNameAdapter(client *apigateway.Client, accountID string, AccountID: accountID, Region: region, AdapterMetadata: apiGatewayDomainNameAdapterMetadata, - cache: cache, + cache: cache, GetFunc: func(ctx context.Context, client *apigateway.Client, scope, query string) (*types.DomainName, error) { if query == "" { return nil, &sdp.QueryError{ diff --git a/aws-source/adapters/apigateway-domain-name_test.go b/aws-source/adapters/apigateway-domain-name_test.go index f20d09b5..9a4f058d 100644 --- a/aws-source/adapters/apigateway-domain-name_test.go +++ b/aws-source/adapters/apigateway-domain-name_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) /* @@ -122,7 +123,7 @@ func TestNewAPIGatewayDomainNameAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayDomainNameAdapter(client, account, region, nil) + adapter := NewAPIGatewayDomainNameAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-integration.go b/aws-source/adapters/apigateway-integration.go index 52848746..bca58a04 100644 --- a/aws-source/adapters/apigateway-integration.go +++ b/aws-source/adapters/apigateway-integration.go @@ -95,7 +95,7 @@ func NewAPIGatewayIntegrationAdapter(client apiGatewayIntegrationGetter, account AccountID: accountID, Region: region, AdapterMetadata: apiGatewayIntegrationAdapterMetadata, - cache: cache, + cache: cache, GetFunc: apiGatewayIntegrationGetFunc, GetInputMapper: func(scope, query string) *apigateway.GetIntegrationInput { // We are using a custom id of {rest-api-id}/{resource-id}/{http-method} e.g. 
diff --git a/aws-source/adapters/apigateway-integration_test.go b/aws-source/adapters/apigateway-integration_test.go index 592728ed..efec09ee 100644 --- a/aws-source/adapters/apigateway-integration_test.go +++ b/aws-source/adapters/apigateway-integration_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) type mockAPIGatewayIntegrationClient struct{} @@ -77,7 +78,7 @@ func TestNewAPIGatewayIntegrationAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayIntegrationAdapter(client, account, region, nil) + adapter := NewAPIGatewayIntegrationAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-method-response.go b/aws-source/adapters/apigateway-method-response.go index ea8cab42..7e88cc03 100644 --- a/aws-source/adapters/apigateway-method-response.go +++ b/aws-source/adapters/apigateway-method-response.go @@ -73,7 +73,7 @@ func NewAPIGatewayMethodResponseAdapter(client apigatewayClient, accountID strin AccountID: accountID, Region: region, AdapterMetadata: apiGatewayMethodResponseAdapterMetadata, - cache: cache, + cache: cache, GetFunc: apiGatewayMethodResponseGetFunc, GetInputMapper: func(scope, query string) *apigateway.GetMethodResponseInput { // We are using a custom id of {rest-api-id}/{resource-id}/{http-method}/{status-code} e.g. diff --git a/aws-source/adapters/apigateway-method-response_test.go b/aws-source/adapters/apigateway-method-response_test.go index 665a62c2..6c351c4f 100644 --- a/aws-source/adapters/apigateway-method-response_test.go +++ b/aws-source/adapters/apigateway-method-response_test.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (m *mockAPIGatewayClient) GetMethodResponse(ctx context.Context, params *apigateway.GetMethodResponseInput, optFns ...func(*apigateway.Options)) (*apigateway.GetMethodResponseOutput, error) { @@ -59,7 +60,7 @@ func TestNewAPIGatewayMethodResponseAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayMethodResponseAdapter(client, account, region, nil) + adapter := NewAPIGatewayMethodResponseAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-method.go b/aws-source/adapters/apigateway-method.go index b1f30bc8..997af413 100644 --- a/aws-source/adapters/apigateway-method.go +++ b/aws-source/adapters/apigateway-method.go @@ -131,7 +131,7 @@ func NewAPIGatewayMethodAdapter(client apigatewayClient, accountID string, regio AccountID: accountID, Region: region, AdapterMetadata: apiGatewayMethodAdapterMetadata, - cache: cache, + cache: cache, GetFunc: apiGatewayMethodGetFunc, GetInputMapper: func(scope, query string) *apigateway.GetMethodInput { // We are using a custom id of {rest-api-id}/{resource-id}/{http-method} e.g. 
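Reviewer note: the *_test.go updates in this and the surrounding files all follow one pattern, so a single hedged sketch may save scanning every hunk: E2E/constructor tests now pass sdpcache.NewNoOpCache() where they previously passed nil, while the new caching-behaviour tests build adapters with sdpcache.NewMemoryCache(). The snippet below restates that wiring with one representative constructor from this diff; it is illustrative only, assumes the adapters test package context (GetAutoConfig is the existing helper used throughout these files), and is not an additional change.

package adapters

import (
	"testing"

	"github.com/aws/aws-sdk-go-v2/service/apigateway"
	"github.com/overmindtech/cli/sdpcache"
)

// Hypothetical sketch, not included in this PR.
func TestApiKeyAdapterCacheWiringSketch(t *testing.T) {
	config, account, region := GetAutoConfig(t)
	client := apigateway.NewFromConfig(config)

	// E2E constructor tests: a no-op cache stands in where nil used to be passed.
	e2eAdapter := NewAPIGatewayApiKeyAdapter(client, account, region, sdpcache.NewNoOpCache())

	// Tests asserting cache hits/misses construct the adapter with a real
	// in-memory cache instead, as in the not-found caching tests above.
	cachingAdapter := NewAPIGatewayApiKeyAdapter(client, account, region, sdpcache.NewMemoryCache())

	_, _ = e2eAdapter, cachingAdapter
}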
diff --git a/aws-source/adapters/apigateway-method_test.go b/aws-source/adapters/apigateway-method_test.go index ce3ab28e..8e71e819 100644 --- a/aws-source/adapters/apigateway-method_test.go +++ b/aws-source/adapters/apigateway-method_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) type mockAPIGatewayClient struct{} @@ -101,7 +102,7 @@ func TestNewAPIGatewayMethodAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayMethodAdapter(client, account, region, nil) + adapter := NewAPIGatewayMethodAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-model_test.go b/aws-source/adapters/apigateway-model_test.go index 3ba69e18..33f2d343 100644 --- a/aws-source/adapters/apigateway-model_test.go +++ b/aws-source/adapters/apigateway-model_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestModelOutputMapper(t *testing.T) { @@ -45,7 +46,7 @@ func TestNewAPIGatewayModelAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayModelAdapter(client, account, region, nil) + adapter := NewAPIGatewayModelAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-resource_test.go b/aws-source/adapters/apigateway-resource_test.go index 20b8a786..a6bac53c 100644 --- a/aws-source/adapters/apigateway-resource_test.go +++ b/aws-source/adapters/apigateway-resource_test.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" + "github.com/overmindtech/cli/sdpcache" ) /* @@ -166,7 +167,7 @@ func TestNewAPIGatewayResourceAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayResourceAdapter(client, account, region, nil) + adapter := NewAPIGatewayResourceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-rest-api_test.go b/aws-source/adapters/apigateway-rest-api_test.go index ff83c09e..f53a0468 100644 --- a/aws-source/adapters/apigateway-rest-api_test.go +++ b/aws-source/adapters/apigateway-rest-api_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) /* @@ -120,7 +121,7 @@ func TestNewAPIGatewayRestApiAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayRestApiAdapter(client, account, region, nil) + adapter := NewAPIGatewayRestApiAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/apigateway-stage_test.go b/aws-source/adapters/apigateway-stage_test.go index e4a0e4e0..c852377c 100644 --- a/aws-source/adapters/apigateway-stage_test.go +++ b/aws-source/adapters/apigateway-stage_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/apigateway" "github.com/aws/aws-sdk-go-v2/service/apigateway/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestStageOutputMapper(t 
*testing.T) { @@ -66,7 +67,7 @@ func TestNewAPIGatewayStageAdapter(t *testing.T) { client := apigateway.NewFromConfig(config) - adapter := NewAPIGatewayStageAdapter(client, account, region, nil) + adapter := NewAPIGatewayStageAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/autoscaling-auto-scaling-group_test.go b/aws-source/adapters/autoscaling-auto-scaling-group_test.go index 665f6a31..ea3a4c79 100644 --- a/aws-source/adapters/autoscaling-auto-scaling-group_test.go +++ b/aws-source/adapters/autoscaling-auto-scaling-group_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/autoscaling" "github.com/aws/aws-sdk-go-v2/service/autoscaling/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestAutoScalingGroupOutputMapper(t *testing.T) { @@ -226,7 +227,7 @@ func TestAutoScalingGroupOutputMapper(t *testing.T) { func TestAutoScalingGroupInputMapperSearch(t *testing.T) { t.Parallel() - adapter := NewAutoScalingGroupAdapter(&autoscaling.Client{}, "123456789012", "us-east-1", nil) + adapter := NewAutoScalingGroupAdapter(&autoscaling.Client{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) tests := []struct { name string diff --git a/aws-source/adapters/autoscaling-auto-scaling-policy_test.go b/aws-source/adapters/autoscaling-auto-scaling-policy_test.go index 864d55f3..27762c28 100644 --- a/aws-source/adapters/autoscaling-auto-scaling-policy_test.go +++ b/aws-source/adapters/autoscaling-auto-scaling-policy_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/autoscaling" "github.com/aws/aws-sdk-go-v2/service/autoscaling/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestScalingPolicyOutputMapper(t *testing.T) { @@ -360,7 +361,7 @@ func TestParseResourceLabelLinks(t *testing.T) { func TestScalingPolicyInputMapperSearch(t *testing.T) { t.Parallel() - adapter := NewAutoScalingPolicyAdapter(&autoscaling.Client{}, "123456789012", "us-east-1", nil) + adapter := NewAutoScalingPolicyAdapter(&autoscaling.Client{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) tests := []struct { name string @@ -464,7 +465,7 @@ func TestScalingPolicyInputMapperSearch(t *testing.T) { func TestScalingPolicyInputMapperGet(t *testing.T) { t.Parallel() - adapter := NewAutoScalingPolicyAdapter(&autoscaling.Client{}, "123456789012", "us-east-1", nil) + adapter := NewAutoScalingPolicyAdapter(&autoscaling.Client{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) tests := []struct { name string diff --git a/aws-source/adapters/cloudfront-cache-policy_test.go b/aws-source/adapters/cloudfront-cache-policy_test.go index 681b97d8..db2b8cac 100644 --- a/aws-source/adapters/cloudfront-cache-policy_test.go +++ b/aws-source/adapters/cloudfront-cache-policy_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudfront" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/overmindtech/cli/sdpcache" ) var testCachePolicy = &types.CachePolicy{ @@ -86,7 +87,7 @@ func TestCachePolicyListFunc(t *testing.T) { func TestNewCloudfrontCachePolicyAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontCachePolicyAdapter(client, account, nil) + adapter := NewCloudfrontCachePolicyAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-continuous-deployment-policy_test.go 
b/aws-source/adapters/cloudfront-continuous-deployment-policy_test.go index ffde6dc9..ad4f2576 100644 --- a/aws-source/adapters/cloudfront-continuous-deployment-policy_test.go +++ b/aws-source/adapters/cloudfront-continuous-deployment-policy_test.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestContinuousDeploymentPolicyItemMapper(t *testing.T) { @@ -60,7 +61,7 @@ func TestContinuousDeploymentPolicyItemMapper(t *testing.T) { func TestNewCloudfrontContinuousDeploymentPolicyAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontContinuousDeploymentPolicyAdapter(client, account, nil) + adapter := NewCloudfrontContinuousDeploymentPolicyAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-distribution.go b/aws-source/adapters/cloudfront-distribution.go index 71c704bf..78422672 100644 --- a/aws-source/adapters/cloudfront-distribution.go +++ b/aws-source/adapters/cloudfront-distribution.go @@ -637,7 +637,7 @@ func NewCloudfrontDistributionAdapter(client CloudFrontClient, accountID string, Client: client, AccountID: accountID, AdapterMetadata: distributionAdapterMetadata, - cache: cache, + cache: cache, Region: "", // Cloudfront resources aren't tied to a region ListInput: &cloudfront.ListDistributionsInput{}, ListFuncPaginatorBuilder: func(client CloudFrontClient, input *cloudfront.ListDistributionsInput) Paginator[*cloudfront.ListDistributionsOutput, *cloudfront.Options] { diff --git a/aws-source/adapters/cloudfront-distribution_test.go b/aws-source/adapters/cloudfront-distribution_test.go index 5e5b5c65..4a35cab8 100644 --- a/aws-source/adapters/cloudfront-distribution_test.go +++ b/aws-source/adapters/cloudfront-distribution_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudfront" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t TestCloudFrontClient) GetDistribution(ctx context.Context, params *cloudfront.GetDistributionInput, optFns ...func(*cloudfront.Options)) (*cloudfront.GetDistributionOutput, error) { @@ -498,7 +499,7 @@ func TestNewCloudfrontDistributionAdapter(t *testing.T) { config, account, _ := GetAutoConfig(t) client := cloudfront.NewFromConfig(config) - adapter := NewCloudfrontDistributionAdapter(client, account, nil) + adapter := NewCloudfrontDistributionAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-function_test.go b/aws-source/adapters/cloudfront-function_test.go index 52e2ea00..a1c903f0 100644 --- a/aws-source/adapters/cloudfront-function_test.go +++ b/aws-source/adapters/cloudfront-function_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/overmindtech/cli/sdpcache" ) func TestFunctionItemMapper(t *testing.T) { @@ -37,7 +38,7 @@ func TestFunctionItemMapper(t *testing.T) { func TestNewCloudfrontCloudfrontFunctionAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontCloudfrontFunctionAdapter(client, account, nil) + adapter := NewCloudfrontCloudfrontFunctionAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-key-group_test.go 
b/aws-source/adapters/cloudfront-key-group_test.go index f366988a..2e9defe0 100644 --- a/aws-source/adapters/cloudfront-key-group_test.go +++ b/aws-source/adapters/cloudfront-key-group_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/overmindtech/cli/sdpcache" ) func TestKeyGroupItemMapper(t *testing.T) { @@ -34,7 +35,7 @@ func TestKeyGroupItemMapper(t *testing.T) { func TestNewCloudfrontKeyGroupAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontKeyGroupAdapter(client, account, nil) + adapter := NewCloudfrontKeyGroupAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-origin-access-control_test.go b/aws-source/adapters/cloudfront-origin-access-control_test.go index 5c4e1964..52ce8711 100644 --- a/aws-source/adapters/cloudfront-origin-access-control_test.go +++ b/aws-source/adapters/cloudfront-origin-access-control_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/overmindtech/cli/sdpcache" ) func TestOriginAccessControlItemMapper(t *testing.T) { @@ -33,7 +34,7 @@ func TestOriginAccessControlItemMapper(t *testing.T) { func TestNewCloudfrontOriginAccessControlAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontOriginAccessControlAdapter(client, account, nil) + adapter := NewCloudfrontOriginAccessControlAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-origin-request-policy_test.go b/aws-source/adapters/cloudfront-origin-request-policy_test.go index 33603f34..fb00556d 100644 --- a/aws-source/adapters/cloudfront-origin-request-policy_test.go +++ b/aws-source/adapters/cloudfront-origin-request-policy_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/overmindtech/cli/sdpcache" ) func TestOriginRequestPolicyItemMapper(t *testing.T) { @@ -52,7 +53,7 @@ func TestOriginRequestPolicyItemMapper(t *testing.T) { func TestNewCloudfrontOriginRequestPolicyAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontOriginRequestPolicyAdapter(client, account, nil) + adapter := NewCloudfrontOriginRequestPolicyAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-realtime-log-config_test.go b/aws-source/adapters/cloudfront-realtime-log-config_test.go index 48c5ef2f..23ebbe9e 100644 --- a/aws-source/adapters/cloudfront-realtime-log-config_test.go +++ b/aws-source/adapters/cloudfront-realtime-log-config_test.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestRealtimeLogConfigsItemMapper(t *testing.T) { @@ -58,7 +59,7 @@ func TestRealtimeLogConfigsItemMapper(t *testing.T) { func TestNewCloudfrontRealtimeLogConfigsAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontRealtimeLogConfigsAdapter(client, account, nil) + adapter := NewCloudfrontRealtimeLogConfigsAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-response-headers-policy_test.go b/aws-source/adapters/cloudfront-response-headers-policy_test.go index ee30a01b..9557d875 100644 --- 
a/aws-source/adapters/cloudfront-response-headers-policy_test.go +++ b/aws-source/adapters/cloudfront-response-headers-policy_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/overmindtech/cli/sdpcache" ) func TestResponseHeadersPolicyItemMapper(t *testing.T) { @@ -89,7 +90,7 @@ func TestResponseHeadersPolicyItemMapper(t *testing.T) { func TestNewCloudfrontResponseHeadersPolicyAdapter(t *testing.T) { client, account, _ := CloudfrontGetAutoConfig(t) - adapter := NewCloudfrontResponseHeadersPolicyAdapter(client, account, nil) + adapter := NewCloudfrontResponseHeadersPolicyAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudfront-streaming-distribution.go b/aws-source/adapters/cloudfront-streaming-distribution.go index cbb81fdf..4ea0525e 100644 --- a/aws-source/adapters/cloudfront-streaming-distribution.go +++ b/aws-source/adapters/cloudfront-streaming-distribution.go @@ -163,7 +163,7 @@ func NewCloudfrontStreamingDistributionAdapter(client CloudFrontClient, accountI AccountID: accountID, Region: "", // Cloudfront resources aren't tied to a region AdapterMetadata: streamingDistributionAdapterMetadata, - cache: cache, + cache: cache, ListInput: &cloudfront.ListStreamingDistributionsInput{}, ListFuncPaginatorBuilder: func(client CloudFrontClient, input *cloudfront.ListStreamingDistributionsInput) Paginator[*cloudfront.ListStreamingDistributionsOutput, *cloudfront.Options] { return cloudfront.NewListStreamingDistributionsPaginator(client, input) diff --git a/aws-source/adapters/cloudfront-streaming-distribution_test.go b/aws-source/adapters/cloudfront-streaming-distribution_test.go index 721067f3..6ae8ff7c 100644 --- a/aws-source/adapters/cloudfront-streaming-distribution_test.go +++ b/aws-source/adapters/cloudfront-streaming-distribution_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudfront" "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t TestCloudFrontClient) GetStreamingDistribution(ctx context.Context, params *cloudfront.GetStreamingDistributionInput, optFns ...func(*cloudfront.Options)) (*cloudfront.GetStreamingDistributionOutput, error) { @@ -110,7 +111,7 @@ func TestNewCloudfrontStreamingDistributionAdapter(t *testing.T) { config, account, _ := GetAutoConfig(t) client := cloudfront.NewFromConfig(config) - adapter := NewCloudfrontStreamingDistributionAdapter(client, account, nil) + adapter := NewCloudfrontStreamingDistributionAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/cloudwatch-alarm_test.go b/aws-source/adapters/cloudwatch-alarm_test.go index 13ba9f31..929e2269 100644 --- a/aws-source/adapters/cloudwatch-alarm_test.go +++ b/aws-source/adapters/cloudwatch-alarm_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudwatch/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) type testCloudwatchClient struct{} @@ -218,7 +219,7 @@ func TestNewCloudwatchAlarmAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := cloudwatch.NewFromConfig(config) - adapter := NewCloudwatchAlarmAdapter(client, account, region, nil) + adapter := NewCloudwatchAlarmAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git 
a/aws-source/adapters/cloudwatch-instance-metric.go b/aws-source/adapters/cloudwatch-instance-metric.go index 55c69aa4..9823da80 100644 --- a/aws-source/adapters/cloudwatch-instance-metric.go +++ b/aws-source/adapters/cloudwatch-instance-metric.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "regexp" - "sync" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -197,21 +196,6 @@ func (a *CloudwatchInstanceMetricAdapter) cacheDuration() time.Duration { return a.CacheDuration } -var ( - noOpCacheCloudwatchOnce sync.Once - noOpCacheCloudwatch sdpcache.Cache -) - -func (a *CloudwatchInstanceMetricAdapter) Cache() sdpcache.Cache { - if a.cache == nil { - noOpCacheCloudwatchOnce.Do(func() { - noOpCacheCloudwatch = sdpcache.NewNoOpCache() - }) - return noOpCacheCloudwatch - } - return a.cache -} - // Type returns the type of items this adapter returns func (a *CloudwatchInstanceMetricAdapter) Type() string { return "cloudwatch-instance-metric" @@ -258,7 +242,7 @@ func (a *CloudwatchInstanceMetricAdapter) Get(ctx context.Context, scope string, var cachedItems []*sdp.Item var qErr *sdp.QueryError - cacheHit, ck, cachedItems, qErr, done := a.Cache().Lookup(ctx, a.Name(), sdp.QueryMethod_GET, scope, a.Type(), query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done := a.cache.Lookup(ctx, a.Name(), sdp.QueryMethod_GET, scope, a.Type(), query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -309,7 +293,7 @@ func (a *CloudwatchInstanceMetricAdapter) Get(ctx context.Context, scope string, Scope: scope, } // Cache the error - a.Cache().StoreError(ctx, qErr, a.cacheDuration(), ck) + a.cache.StoreError(ctx, qErr, a.cacheDuration(), ck) return nil, qErr } @@ -321,12 +305,12 @@ func (a *CloudwatchInstanceMetricAdapter) Get(ctx context.Context, scope string, Scope: scope, } // Cache the error - a.Cache().StoreError(ctx, qErr, a.cacheDuration(), ck) + a.cache.StoreError(ctx, qErr, a.cacheDuration(), ck) return nil, qErr } // Store in cache - a.Cache().StoreItem(ctx, item, a.cacheDuration(), ck) + a.cache.StoreItem(ctx, item, a.cacheDuration(), ck) return item, nil } @@ -358,7 +342,7 @@ func NewCloudwatchInstanceMetricAdapter(client *cloudwatch.Client, accountID str Client: client, AccountID: accountID, Region: region, - cache: cache, + cache: cache, } } diff --git a/aws-source/adapters/cloudwatch-instance-metric_integration_test.go b/aws-source/adapters/cloudwatch-instance-metric_integration_test.go index 19370a02..24f6598b 100644 --- a/aws-source/adapters/cloudwatch-instance-metric_integration_test.go +++ b/aws-source/adapters/cloudwatch-instance-metric_integration_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/service/cloudwatch" + "github.com/overmindtech/cli/sdpcache" ) // TestCloudwatchInstanceMetricIntegration fetches real CloudWatch metrics for an EC2 instance @@ -22,7 +23,7 @@ func TestCloudwatchInstanceMetricIntegration(t *testing.T) { config, account, region := GetAutoConfig(t) client := cloudwatch.NewFromConfig(config) - adapter := NewCloudwatchInstanceMetricAdapter(client, account, region, nil) + adapter := NewCloudwatchInstanceMetricAdapter(client, account, region, sdpcache.NewNoOpCache()) scope := FormatScope(account, region) // Query is just the instance ID diff --git a/aws-source/adapters/cloudwatch-instance-metric_test.go b/aws-source/adapters/cloudwatch-instance-metric_test.go index f7849147..e3aefc1f 100644 --- a/aws-source/adapters/cloudwatch-instance-metric_test.go +++ b/aws-source/adapters/cloudwatch-instance-metric_test.go @@ -266,6 +266,7 @@ func 
TestCloudwatchInstanceMetricAdapterGet(t *testing.T) { Client: testCloudwatchMetricClient{}, AccountID: "123456789012", Region: "eu-west-2", + cache: sdpcache.NewNoOpCache(), } scope := "123456789012.eu-west-2" @@ -298,6 +299,7 @@ func TestCloudwatchInstanceMetricAdapterGetWrongScope(t *testing.T) { Client: testCloudwatchMetricClient{}, AccountID: "123456789012", Region: "eu-west-2", + cache: sdpcache.NewNoOpCache(), } wrongScope := "999999999999.us-east-1" @@ -314,6 +316,7 @@ func TestCloudwatchInstanceMetricAdapterGetInvalidQuery(t *testing.T) { Client: testCloudwatchMetricClient{}, AccountID: "123456789012", Region: "eu-west-2", + cache: sdpcache.NewNoOpCache(), } scope := "123456789012.eu-west-2" @@ -343,6 +346,7 @@ func TestCloudwatchInstanceMetricAdapterList(t *testing.T) { Client: testCloudwatchMetricClient{}, AccountID: "123456789012", Region: "eu-west-2", + cache: sdpcache.NewNoOpCache(), } scope := "123456789012.eu-west-2" @@ -394,7 +398,7 @@ func TestNewCloudwatchInstanceMetricAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := cloudwatch.NewFromConfig(config) - adapter := NewCloudwatchInstanceMetricAdapter(client, account, region, nil) + adapter := NewCloudwatchInstanceMetricAdapter(client, account, region, sdpcache.NewNoOpCache()) if adapter.Type() != "cloudwatch-instance-metric" { t.Errorf("expected type cloudwatch-instance-metric, got %s", adapter.Type()) @@ -406,13 +410,12 @@ func TestNewCloudwatchInstanceMetricAdapter(t *testing.T) { } func TestCloudwatchInstanceMetricAdapterCaching(t *testing.T) { - ctx := t.Context() client := &testCloudwatchMetricClientWithCallCount{} adapter := &CloudwatchInstanceMetricAdapter{ Client: client, AccountID: "123456789012", Region: "eu-west-2", - cache: sdpcache.NewCache(ctx), + cache: sdpcache.NewMemoryCache(), } scope := "123456789012.eu-west-2" @@ -458,6 +461,7 @@ func TestCloudwatchInstanceMetricAdapterIgnoreCache(t *testing.T) { Client: client, AccountID: "123456789012", Region: "eu-west-2", + cache: sdpcache.NewNoOpCache(), } scope := "123456789012.eu-west-2" @@ -495,6 +499,7 @@ func TestCloudwatchInstanceMetricAdapterErrorCaching(t *testing.T) { Client: testCloudwatchMetricClientError{}, AccountID: "123456789012", Region: "eu-west-2", + cache: sdpcache.NewNoOpCache(), } scope := "123456789012.eu-west-2" diff --git a/aws-source/adapters/directconnect-connection_test.go b/aws-source/adapters/directconnect-connection_test.go index 3bff2740..c4b5e264 100644 --- a/aws-source/adapters/directconnect-connection_test.go +++ b/aws-source/adapters/directconnect-connection_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDirectconnectConnectionOutputMapper(t *testing.T) { @@ -90,7 +91,7 @@ func TestDirectconnectConnectionOutputMapper(t *testing.T) { func TestNewDirectConnectConnectionAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectConnectionAdapter(client, account, region, nil) + adapter := NewDirectConnectConnectionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-customer-metadata_test.go b/aws-source/adapters/directconnect-customer-metadata_test.go index 090a9372..491a2b13 100644 --- a/aws-source/adapters/directconnect-customer-metadata_test.go +++ 
b/aws-source/adapters/directconnect-customer-metadata_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" + "github.com/overmindtech/cli/sdpcache" ) func TestCustomerMetadataOutputMapper(t *testing.T) { @@ -38,7 +39,7 @@ func TestCustomerMetadataOutputMapper(t *testing.T) { func TestNewDirectConnectCustomerMetadataAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectCustomerMetadataAdapter(client, account, region, nil) + adapter := NewDirectConnectCustomerMetadataAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-direct-connect-gateway-association-proposal_test.go b/aws-source/adapters/directconnect-direct-connect-gateway-association-proposal_test.go index 6882421e..5c836b06 100644 --- a/aws-source/adapters/directconnect-direct-connect-gateway-association-proposal_test.go +++ b/aws-source/adapters/directconnect-direct-connect-gateway-association-proposal_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDirectConnectGatewayAssociationProposalOutputMapper(t *testing.T) { @@ -75,7 +76,7 @@ func TestDirectConnectGatewayAssociationProposalOutputMapper(t *testing.T) { func TestNewDirectConnectGatewayAssociationProposalAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectGatewayAssociationProposalAdapter(client, account, region, nil) + adapter := NewDirectConnectGatewayAssociationProposalAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-direct-connect-gateway-association_test.go b/aws-source/adapters/directconnect-direct-connect-gateway-association_test.go index c4b1c33f..f46e8b0c 100644 --- a/aws-source/adapters/directconnect-direct-connect-gateway-association_test.go +++ b/aws-source/adapters/directconnect-direct-connect-gateway-association_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDirectConnectGatewayAssociationOutputMapper_Health_OK(t *testing.T) { @@ -118,7 +119,7 @@ func TestDirectConnectGatewayAssociationOutputMapper_Health_Error(t *testing.T) func TestNewDirectConnectGatewayAssociationAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectGatewayAssociationAdapter(client, account, region, nil) + adapter := NewDirectConnectGatewayAssociationAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-direct-connect-gateway-attachment_test.go b/aws-source/adapters/directconnect-direct-connect-gateway-attachment_test.go index b94970b9..1b8ad52d 100644 --- a/aws-source/adapters/directconnect-direct-connect-gateway-attachment_test.go +++ b/aws-source/adapters/directconnect-direct-connect-gateway-attachment_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func 
TestDirectConnectGatewayAttachmentOutputMapper_Health_OK(t *testing.T) { @@ -118,7 +119,7 @@ func TestDirectConnectGatewayAttachmentOutputMapper_Health_Error(t *testing.T) { func TestNewDirectConnectGatewayAttachmentAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectGatewayAttachmentAdapter(client, account, region, nil) + adapter := NewDirectConnectGatewayAttachmentAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-direct-connect-gateway_test.go b/aws-source/adapters/directconnect-direct-connect-gateway_test.go index 36ab6058..c70f9f36 100644 --- a/aws-source/adapters/directconnect-direct-connect-gateway_test.go +++ b/aws-source/adapters/directconnect-direct-connect-gateway_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDirectConnectGatewayOutputMapper_Health_OK(t *testing.T) { @@ -80,7 +81,7 @@ func TestDirectConnectGatewayOutputMapper_Health_ERROR(t *testing.T) { func TestNewDirectConnectGatewayAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectGatewayAdapter(client, account, region, nil) + adapter := NewDirectConnectGatewayAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-hosted-connection_test.go b/aws-source/adapters/directconnect-hosted-connection_test.go index fc9e9156..1ef0f8e2 100644 --- a/aws-source/adapters/directconnect-hosted-connection_test.go +++ b/aws-source/adapters/directconnect-hosted-connection_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestHostedConnectionOutputMapper(t *testing.T) { @@ -90,7 +91,7 @@ func TestHostedConnectionOutputMapper(t *testing.T) { func TestNewDirectConnectHostedConnectionAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectHostedConnectionAdapter(client, account, region, nil) + adapter := NewDirectConnectHostedConnectionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-interconnect_test.go b/aws-source/adapters/directconnect-interconnect_test.go index 3922012c..108f15d5 100644 --- a/aws-source/adapters/directconnect-interconnect_test.go +++ b/aws-source/adapters/directconnect-interconnect_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestInterconnectOutputMapper(t *testing.T) { @@ -149,7 +150,7 @@ func TestInterconnectHealth(t *testing.T) { func TestNewDirectConnectInterconnectAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectInterconnectAdapter(client, account, region, nil) + adapter := NewDirectConnectInterconnectAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-lag_test.go 
b/aws-source/adapters/directconnect-lag_test.go index 48fd4e36..95b0815c 100644 --- a/aws-source/adapters/directconnect-lag_test.go +++ b/aws-source/adapters/directconnect-lag_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" @@ -167,7 +168,7 @@ func TestLagOutputMapper(t *testing.T) { func TestNewDirectConnectLagAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectLagAdapter(client, account, region, nil) + adapter := NewDirectConnectLagAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-location_test.go b/aws-source/adapters/directconnect-location_test.go index 16b484d5..bd36c612 100644 --- a/aws-source/adapters/directconnect-location_test.go +++ b/aws-source/adapters/directconnect-location_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" + "github.com/overmindtech/cli/sdpcache" ) func TestLocationOutputMapper(t *testing.T) { @@ -42,7 +43,7 @@ func TestLocationOutputMapper(t *testing.T) { func TestNewDirectConnectLocationAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectLocationAdapter(client, account, region, nil) + adapter := NewDirectConnectLocationAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-router-configuration_test.go b/aws-source/adapters/directconnect-router-configuration_test.go index d6d9c48e..8a124b6e 100644 --- a/aws-source/adapters/directconnect-router-configuration_test.go +++ b/aws-source/adapters/directconnect-router-configuration_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestRouterConfigurationOutputMapper(t *testing.T) { @@ -57,7 +58,7 @@ func TestRouterConfigurationOutputMapper(t *testing.T) { func TestNewDirectConnectRouterConfigurationAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectRouterConfigurationAdapter(client, account, region, nil) + adapter := NewDirectConnectRouterConfigurationAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-virtual-gateway_test.go b/aws-source/adapters/directconnect-virtual-gateway_test.go index 618125d1..2f6252b9 100644 --- a/aws-source/adapters/directconnect-virtual-gateway_test.go +++ b/aws-source/adapters/directconnect-virtual-gateway_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect" "github.com/aws/aws-sdk-go-v2/service/directconnect/types" + "github.com/overmindtech/cli/sdpcache" ) func TestVirtualGatewayOutputMapper(t *testing.T) { @@ -38,7 +39,7 @@ func TestVirtualGatewayOutputMapper(t *testing.T) { func TestNewDirectConnectVirtualGatewayAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectVirtualGatewayAdapter(client, account, region, nil) + adapter := NewDirectConnectVirtualGatewayAdapter(client, account, region, sdpcache.NewNoOpCache()) test 
:= E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/directconnect-virtual-interface_test.go b/aws-source/adapters/directconnect-virtual-interface_test.go index 5a5a3e17..6a5463e8 100644 --- a/aws-source/adapters/directconnect-virtual-interface_test.go +++ b/aws-source/adapters/directconnect-virtual-interface_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/directconnect/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestVirtualInterfaceOutputMapper(t *testing.T) { @@ -90,7 +91,7 @@ func TestVirtualInterfaceOutputMapper(t *testing.T) { func TestNewDirectConnectVirtualInterfaceAdapter(t *testing.T) { client, account, region := directconnectGetAutoConfig(t) - adapter := NewDirectConnectVirtualInterfaceAdapter(client, account, region, nil) + adapter := NewDirectConnectVirtualInterfaceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/dynamodb-backup.go b/aws-source/adapters/dynamodb-backup.go index e2b677da..cf67c84d 100644 --- a/aws-source/adapters/dynamodb-backup.go +++ b/aws-source/adapters/dynamodb-backup.go @@ -83,7 +83,7 @@ func NewDynamoDBBackupAdapter(client Client, accountID string, region string, ca GetFunc: backupGetFunc, ListInput: &dynamodb.ListBackupsInput{}, AdapterMetadata: dynamodbBackupAdapterMetadata, - cache: cache, + cache: cache, GetInputMapper: func(scope, query string) *dynamodb.DescribeBackupInput { // Get is not supported since you can't search by name return nil diff --git a/aws-source/adapters/dynamodb-backup_test.go b/aws-source/adapters/dynamodb-backup_test.go index 36596fd9..e4f4b718 100644 --- a/aws-source/adapters/dynamodb-backup_test.go +++ b/aws-source/adapters/dynamodb-backup_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *DynamoDBTestClient) DescribeBackup(ctx context.Context, params *dynamodb.DescribeBackupInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeBackupOutput, error) { @@ -113,7 +114,7 @@ func TestNewDynamoDBBackupAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := dynamodb.NewFromConfig(config) - adapter := NewDynamoDBBackupAdapter(client, account, region, nil) + adapter := NewDynamoDBBackupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/dynamodb-table.go b/aws-source/adapters/dynamodb-table.go index 0f43ca1d..8aa3c8ad 100644 --- a/aws-source/adapters/dynamodb-table.go +++ b/aws-source/adapters/dynamodb-table.go @@ -174,7 +174,7 @@ func NewDynamoDBTableAdapter(client Client, accountID string, region string, cac GetFunc: tableGetFunc, ListInput: &dynamodb.ListTablesInput{}, AdapterMetadata: dynamodbTableAdapterMetadata, - cache: cache, + cache: cache, GetInputMapper: func(scope, query string) *dynamodb.DescribeTableInput { return &dynamodb.DescribeTableInput{ TableName: &query, diff --git a/aws-source/adapters/dynamodb-table_test.go b/aws-source/adapters/dynamodb-table_test.go index 0a9e1116..13589e01 100644 --- a/aws-source/adapters/dynamodb-table_test.go +++ b/aws-source/adapters/dynamodb-table_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) 
func (t *DynamoDBTestClient) DescribeTable(context.Context, *dynamodb.DescribeTableInput, ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) { @@ -223,7 +224,7 @@ func TestNewDynamoDBTableAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := dynamodb.NewFromConfig(config) - adapter := NewDynamoDBTableAdapter(client, account, region, nil) + adapter := NewDynamoDBTableAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-address_test.go b/aws-source/adapters/ec2-address_test.go index 4c523b07..3132a84f 100644 --- a/aws-source/adapters/ec2-address_test.go +++ b/aws-source/adapters/ec2-address_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestAddressInputMapperGet(t *testing.T) { @@ -121,7 +122,7 @@ func TestAddressOutputMapper(t *testing.T) { func TestNewEC2AddressAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2AddressAdapter(client, account, region, nil) + adapter := NewEC2AddressAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-capacity-reservation-fleet_test.go b/aws-source/adapters/ec2-capacity-reservation-fleet_test.go index 0b4fb905..48bf1419 100644 --- a/aws-source/adapters/ec2-capacity-reservation-fleet_test.go +++ b/aws-source/adapters/ec2-capacity-reservation-fleet_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/overmindtech/cli/sdpcache" ) func TestCapacityReservationFleetOutputMapper(t *testing.T) { @@ -71,7 +72,7 @@ func TestCapacityReservationFleetOutputMapper(t *testing.T) { func TestNewEC2CapacityReservationFleetAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2CapacityReservationFleetAdapter(client, account, region, nil) + adapter := NewEC2CapacityReservationFleetAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-capacity-reservation_test.go b/aws-source/adapters/ec2-capacity-reservation_test.go index 3ca0e00d..e2819dc4 100644 --- a/aws-source/adapters/ec2-capacity-reservation_test.go +++ b/aws-source/adapters/ec2-capacity-reservation_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestCapacityReservationOutputMapper(t *testing.T) { @@ -92,7 +93,7 @@ func TestCapacityReservationOutputMapper(t *testing.T) { func TestNewEC2CapacityReservationAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2CapacityReservationAdapter(client, account, region, nil) + adapter := NewEC2CapacityReservationAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-egress-only-internet-gateway_test.go b/aws-source/adapters/ec2-egress-only-internet-gateway_test.go index def87479..677b4b73 100644 --- a/aws-source/adapters/ec2-egress-only-internet-gateway_test.go +++ b/aws-source/adapters/ec2-egress-only-internet-gateway_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" 
"github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestEgressOnlyInternetGatewayInputMapperGet(t *testing.T) { @@ -89,7 +90,7 @@ func TestEgressOnlyInternetGatewayOutputMapper(t *testing.T) { func TestNewEC2EgressOnlyInternetGatewayAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2EgressOnlyInternetGatewayAdapter(client, account, region, nil) + adapter := NewEC2EgressOnlyInternetGatewayAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-iam-instance-profile-association_test.go b/aws-source/adapters/ec2-iam-instance-profile-association_test.go index 158e2b5a..72ab47a6 100644 --- a/aws-source/adapters/ec2-iam-instance-profile-association_test.go +++ b/aws-source/adapters/ec2-iam-instance-profile-association_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestIamInstanceProfileAssociationOutputMapper(t *testing.T) { @@ -67,7 +68,7 @@ func TestIamInstanceProfileAssociationOutputMapper(t *testing.T) { func TestNewEC2IamInstanceProfileAssociationAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2IamInstanceProfileAssociationAdapter(client, account, region, nil) + adapter := NewEC2IamInstanceProfileAssociationAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-image_test.go b/aws-source/adapters/ec2-image_test.go index 60bc810b..b642e410 100644 --- a/aws-source/adapters/ec2-image_test.go +++ b/aws-source/adapters/ec2-image_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/overmindtech/cli/sdpcache" ) func TestImageInputMapperGet(t *testing.T) { @@ -107,7 +108,7 @@ func TestImageOutputMapper(t *testing.T) { func TestNewEC2ImageAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2ImageAdapter(client, account, region, nil) + adapter := NewEC2ImageAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-instance-event-window_test.go b/aws-source/adapters/ec2-instance-event-window_test.go index e36a62ef..a8a7eb16 100644 --- a/aws-source/adapters/ec2-instance-event-window_test.go +++ b/aws-source/adapters/ec2-instance-event-window_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestInstanceEventWindowInputMapperGet(t *testing.T) { @@ -109,7 +110,7 @@ func TestInstanceEventWindowOutputMapper(t *testing.T) { func TestNewEC2InstanceEventWindowAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2InstanceEventWindowAdapter(client, account, region, nil) + adapter := NewEC2InstanceEventWindowAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-instance-status_test.go b/aws-source/adapters/ec2-instance-status_test.go index 8e3bf189..212d16a0 100644 --- a/aws-source/adapters/ec2-instance-status_test.go +++ b/aws-source/adapters/ec2-instance-status_test.go @@ -8,6 +8,7 @@ 
import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestInstanceStatusInputMapperGet(t *testing.T) { @@ -106,7 +107,7 @@ func TestInstanceStatusOutputMapper(t *testing.T) { func TestNewEC2InstanceStatusAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2InstanceStatusAdapter(client, account, region, nil) + adapter := NewEC2InstanceStatusAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-instance_test.go b/aws-source/adapters/ec2-instance_test.go index af0f8415..88288ed6 100644 --- a/aws-source/adapters/ec2-instance_test.go +++ b/aws-source/adapters/ec2-instance_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestInstanceInputMapperGet(t *testing.T) { @@ -351,7 +352,7 @@ func TestInstanceOutputMapper(t *testing.T) { func TestNewEC2InstanceAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2InstanceAdapter(client, account, region, nil) + adapter := NewEC2InstanceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-internet-gateway_test.go b/aws-source/adapters/ec2-internet-gateway_test.go index 0de85350..4d2bb563 100644 --- a/aws-source/adapters/ec2-internet-gateway_test.go +++ b/aws-source/adapters/ec2-internet-gateway_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestInternetGatewayInputMapperGet(t *testing.T) { @@ -96,7 +97,7 @@ func TestInternetGatewayOutputMapper(t *testing.T) { func TestNewEC2InternetGatewayAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2InternetGatewayAdapter(client, account, region, nil) + adapter := NewEC2InternetGatewayAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-key-pair_test.go b/aws-source/adapters/ec2-key-pair_test.go index 8096ff98..89522608 100644 --- a/aws-source/adapters/ec2-key-pair_test.go +++ b/aws-source/adapters/ec2-key-pair_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/overmindtech/cli/sdpcache" ) func TestKeyPairInputMapperGet(t *testing.T) { @@ -73,7 +74,7 @@ func TestKeyPairOutputMapper(t *testing.T) { func TestNewEC2KeyPairAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2KeyPairAdapter(client, account, region, nil) + adapter := NewEC2KeyPairAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-launch-template-version_test.go b/aws-source/adapters/ec2-launch-template-version_test.go index 29ca9cf7..a86e9d43 100644 --- a/aws-source/adapters/ec2-launch-template-version_test.go +++ b/aws-source/adapters/ec2-launch-template-version_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func 
TestLaunchTemplateVersionInputMapperGet(t *testing.T) { @@ -204,7 +205,7 @@ func TestLaunchTemplateVersionOutputMapper(t *testing.T) { func TestNewEC2LaunchTemplateVersionAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2LaunchTemplateVersionAdapter(client, account, region, nil) + adapter := NewEC2LaunchTemplateVersionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-launch-template_test.go b/aws-source/adapters/ec2-launch-template_test.go index 21710c09..c2f8e4ff 100644 --- a/aws-source/adapters/ec2-launch-template_test.go +++ b/aws-source/adapters/ec2-launch-template_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/overmindtech/cli/sdpcache" ) func TestLaunchTemplateInputMapperGet(t *testing.T) { @@ -67,7 +68,7 @@ func TestLaunchTemplateOutputMapper(t *testing.T) { func TestNewEC2LaunchTemplateAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2LaunchTemplateAdapter(client, account, region, nil) + adapter := NewEC2LaunchTemplateAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-nat-gateway_test.go b/aws-source/adapters/ec2-nat-gateway_test.go index 2fba9f97..70e4f351 100644 --- a/aws-source/adapters/ec2-nat-gateway_test.go +++ b/aws-source/adapters/ec2-nat-gateway_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestNatGatewayInputMapperGet(t *testing.T) { @@ -150,7 +151,7 @@ func TestNatGatewayOutputMapper(t *testing.T) { func TestNewEC2NatGatewayAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2NatGatewayAdapter(client, account, region, nil) + adapter := NewEC2NatGatewayAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-network-acl_test.go b/aws-source/adapters/ec2-network-acl_test.go index 651f100f..a0197d24 100644 --- a/aws-source/adapters/ec2-network-acl_test.go +++ b/aws-source/adapters/ec2-network-acl_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestNetworkAclInputMapperGet(t *testing.T) { @@ -132,7 +133,7 @@ func TestNetworkAclOutputMapper(t *testing.T) { func TestNewEC2NetworkAclAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2NetworkAclAdapter(client, account, region, nil) + adapter := NewEC2NetworkAclAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-network-interface-permission.go b/aws-source/adapters/ec2-network-interface-permission.go index 2eea4280..bb25592d 100644 --- a/aws-source/adapters/ec2-network-interface-permission.go +++ b/aws-source/adapters/ec2-network-interface-permission.go @@ -73,7 +73,7 @@ func NewEC2NetworkInterfacePermissionAdapter(client *ec2.Client, accountID strin AccountID: accountID, ItemType: "ec2-network-interface-permission", AdapterMetadata: networkInterfacePermissionAdapterMetadata, - cache: cache, + cache: cache, DescribeFunc: func(ctx context.Context, client 
*ec2.Client, input *ec2.DescribeNetworkInterfacePermissionsInput) (*ec2.DescribeNetworkInterfacePermissionsOutput, error) { return client.DescribeNetworkInterfacePermissions(ctx, input) }, diff --git a/aws-source/adapters/ec2-network-interface-permission_test.go b/aws-source/adapters/ec2-network-interface-permission_test.go index e5f016d4..0d07c507 100644 --- a/aws-source/adapters/ec2-network-interface-permission_test.go +++ b/aws-source/adapters/ec2-network-interface-permission_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestNetworkInterfacePermissionInputMapperGet(t *testing.T) { @@ -89,7 +90,7 @@ func TestNetworkInterfacePermissionOutputMapper(t *testing.T) { func TestNewEC2NetworkInterfacePermissionAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2NetworkInterfacePermissionAdapter(client, account, region, nil) + adapter := NewEC2NetworkInterfacePermissionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-network-interface_test.go b/aws-source/adapters/ec2-network-interface_test.go index d6bb2c94..bbf5713b 100644 --- a/aws-source/adapters/ec2-network-interface_test.go +++ b/aws-source/adapters/ec2-network-interface_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestNetworkInterfaceInputMapperGet(t *testing.T) { @@ -281,7 +282,7 @@ func TestNetworkInterfaceOutputMapper(t *testing.T) { func TestNewEC2NetworkInterfaceAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2NetworkInterfaceAdapter(client, account, region, nil) + adapter := NewEC2NetworkInterfaceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-placement-group_test.go b/aws-source/adapters/ec2-placement-group_test.go index a4ee82c9..f17dfbd9 100644 --- a/aws-source/adapters/ec2-placement-group_test.go +++ b/aws-source/adapters/ec2-placement-group_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/overmindtech/cli/sdpcache" ) func TestPlacementGroupInputMapperGet(t *testing.T) { @@ -74,7 +75,7 @@ func TestPlacementGroupOutputMapper(t *testing.T) { func TestNewEC2PlacementGroupAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2PlacementGroupAdapter(client, account, region, nil) + adapter := NewEC2PlacementGroupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-reserved-instance_test.go b/aws-source/adapters/ec2-reserved-instance_test.go index a86fa5fa..ae67edc8 100644 --- a/aws-source/adapters/ec2-reserved-instance_test.go +++ b/aws-source/adapters/ec2-reserved-instance_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/overmindtech/cli/sdpcache" ) func TestReservedInstanceInputMapperGet(t *testing.T) { @@ -96,7 +97,7 @@ func TestReservedInstanceOutputMapper(t *testing.T) { func TestNewEC2ReservedInstanceAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := 
NewEC2ReservedInstanceAdapter(client, account, region, nil) + adapter := NewEC2ReservedInstanceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-route-table_test.go b/aws-source/adapters/ec2-route-table_test.go index 8d35427b..468e7a9a 100644 --- a/aws-source/adapters/ec2-route-table_test.go +++ b/aws-source/adapters/ec2-route-table_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestRouteTableInputMapperGet(t *testing.T) { @@ -197,7 +198,7 @@ func TestRouteTableOutputMapper(t *testing.T) { func TestNewEC2RouteTableAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2RouteTableAdapter(client, account, region, nil) + adapter := NewEC2RouteTableAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-security-group-rule_test.go b/aws-source/adapters/ec2-security-group-rule_test.go index eac6c9a9..3fdf07ad 100644 --- a/aws-source/adapters/ec2-security-group-rule_test.go +++ b/aws-source/adapters/ec2-security-group-rule_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestSecurityGroupRuleInputMapperGet(t *testing.T) { @@ -110,7 +111,7 @@ func TestSecurityGroupRuleOutputMapper(t *testing.T) { func TestNewEC2SecurityGroupRuleAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2SecurityGroupRuleAdapter(client, account, region, nil) + adapter := NewEC2SecurityGroupRuleAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-security-group_test.go b/aws-source/adapters/ec2-security-group_test.go index 7b532e7b..9ccb2422 100644 --- a/aws-source/adapters/ec2-security-group_test.go +++ b/aws-source/adapters/ec2-security-group_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestSecurityGroupInputMapperGet(t *testing.T) { @@ -120,7 +121,7 @@ func TestSecurityGroupOutputMapper(t *testing.T) { func TestNewEC2SecurityGroupAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2SecurityGroupAdapter(client, account, region, nil) + adapter := NewEC2SecurityGroupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-snapshot_test.go b/aws-source/adapters/ec2-snapshot_test.go index 809659c4..4bacc657 100644 --- a/aws-source/adapters/ec2-snapshot_test.go +++ b/aws-source/adapters/ec2-snapshot_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestSnapshotInputMapperGet(t *testing.T) { @@ -93,7 +94,7 @@ func TestSnapshotOutputMapper(t *testing.T) { func TestNewEC2SnapshotAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2SnapshotAdapter(client, account, region, nil) + adapter := NewEC2SnapshotAdapter(client, account, region, 
sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-subnet_test.go b/aws-source/adapters/ec2-subnet_test.go index ded05d8e..64e121eb 100644 --- a/aws-source/adapters/ec2-subnet_test.go +++ b/aws-source/adapters/ec2-subnet_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestSubnetInputMapperGet(t *testing.T) { @@ -113,7 +114,7 @@ func TestSubnetOutputMapper(t *testing.T) { func TestNewEC2SubnetAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2SubnetAdapter(client, account, region, nil) + adapter := NewEC2SubnetAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-volume-status_test.go b/aws-source/adapters/ec2-volume-status_test.go index fa271595..7c01e8e3 100644 --- a/aws-source/adapters/ec2-volume-status_test.go +++ b/aws-source/adapters/ec2-volume-status_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestVolumeStatusInputMapperGet(t *testing.T) { @@ -120,7 +121,7 @@ func TestVolumeStatusOutputMapper(t *testing.T) { func TestNewEC2VolumeStatusAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2VolumeAdapter(client, account, region, nil) + adapter := NewEC2VolumeAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-volume.go b/aws-source/adapters/ec2-volume.go index 511e4c2e..64351895 100644 --- a/aws-source/adapters/ec2-volume.go +++ b/aws-source/adapters/ec2-volume.go @@ -74,7 +74,7 @@ func NewEC2VolumeAdapter(client *ec2.Client, accountID string, region string, ca AccountID: accountID, ItemType: "ec2-volume", AdapterMetadata: volumeAdapterMetadata, - cache: cache, + cache: cache, DescribeFunc: func(ctx context.Context, client *ec2.Client, input *ec2.DescribeVolumesInput) (*ec2.DescribeVolumesOutput, error) { return client.DescribeVolumes(ctx, input) }, diff --git a/aws-source/adapters/ec2-volume_test.go b/aws-source/adapters/ec2-volume_test.go index beefd23c..9dd02fcd 100644 --- a/aws-source/adapters/ec2-volume_test.go +++ b/aws-source/adapters/ec2-volume_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestVolumeInputMapperGet(t *testing.T) { @@ -102,7 +103,7 @@ func TestVolumeOutputMapper(t *testing.T) { func TestNewEC2VolumeAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2VolumeAdapter(client, account, region, nil) + adapter := NewEC2VolumeAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-vpc-endpoint_test.go b/aws-source/adapters/ec2-vpc-endpoint_test.go index 0b426b61..8010afaa 100644 --- a/aws-source/adapters/ec2-vpc-endpoint_test.go +++ b/aws-source/adapters/ec2-vpc-endpoint_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestVpcEndpointInputMapperGet(t 
*testing.T) { @@ -147,7 +148,7 @@ func TestVpcEndpointOutputMapper(t *testing.T) { func TestNewEC2VpcEndpointAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2VpcEndpointAdapter(client, account, region, nil) + adapter := NewEC2VpcEndpointAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-vpc-peering-connection_test.go b/aws-source/adapters/ec2-vpc-peering-connection_test.go index b51098ac..4a31a669 100644 --- a/aws-source/adapters/ec2-vpc-peering-connection_test.go +++ b/aws-source/adapters/ec2-vpc-peering-connection_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestVpcPeeringConnectionOutputMapper(t *testing.T) { @@ -103,7 +104,7 @@ func TestVpcPeeringConnectionOutputMapper(t *testing.T) { func TestNewEC2VpcPeeringConnectionAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2VpcPeeringConnectionAdapter(client, account, region, nil) + adapter := NewEC2VpcPeeringConnectionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ec2-vpc_test.go b/aws-source/adapters/ec2-vpc_test.go index 8abe6319..edd5cff9 100644 --- a/aws-source/adapters/ec2-vpc_test.go +++ b/aws-source/adapters/ec2-vpc_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/overmindtech/cli/sdpcache" ) func TestVpcInputMapperGet(t *testing.T) { @@ -99,7 +100,7 @@ func TestVpcOutputMapper(t *testing.T) { func TestNewEC2VpcAdapter(t *testing.T) { client, account, region := ec2GetAutoConfig(t) - adapter := NewEC2VpcAdapter(client, account, region, nil) + adapter := NewEC2VpcAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ecs-capacity-provider_test.go b/aws-source/adapters/ecs-capacity-provider_test.go index a3febc32..faf453a8 100644 --- a/aws-source/adapters/ecs-capacity-provider_test.go +++ b/aws-source/adapters/ecs-capacity-provider_test.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *ecsTestClient) DescribeCapacityProviders(ctx context.Context, params *ecs.DescribeCapacityProvidersInput, optFns ...func(*ecs.Options)) (*ecs.DescribeCapacityProvidersOutput, error) { @@ -122,7 +123,7 @@ func TestCapacityProviderOutputMapper(t *testing.T) { } func TestCapacityProviderAdapter(t *testing.T) { - adapter := NewECSCapacityProviderAdapter(&ecsTestClient{}, "", "", nil) + adapter := NewECSCapacityProviderAdapter(&ecsTestClient{}, "", "", sdpcache.NewNoOpCache()) stream := discovery.NewRecordingQueryResultStream() adapter.ListStream(context.Background(), "*", false, stream) @@ -142,7 +143,7 @@ func TestNewECSCapacityProviderAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := ecs.NewFromConfig(config) - adapter := NewECSCapacityProviderAdapter(client, account, region, nil) + adapter := NewECSCapacityProviderAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ecs-cluster_test.go b/aws-source/adapters/ecs-cluster_test.go index 
579c9b89..ebe86965 100644 --- a/aws-source/adapters/ecs-cluster_test.go +++ b/aws-source/adapters/ecs-cluster_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *ecsTestClient) DescribeClusters(ctx context.Context, params *ecs.DescribeClustersInput, optFns ...func(*ecs.Options)) (*ecs.DescribeClustersOutput, error) { @@ -147,7 +148,7 @@ func TestECSClusterGetFunc(t *testing.T) { func TestECSNewECSClusterAdapter(t *testing.T) { client, account, region := ecsGetAutoConfig(t) - adapter := NewECSClusterAdapter(client, account, region, nil) + adapter := NewECSClusterAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ecs-container-instance_test.go b/aws-source/adapters/ecs-container-instance_test.go index 5e5ab52b..73fb9818 100644 --- a/aws-source/adapters/ecs-container-instance_test.go +++ b/aws-source/adapters/ecs-container-instance_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *ecsTestClient) DescribeContainerInstances(ctx context.Context, params *ecs.DescribeContainerInstancesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) { @@ -349,7 +350,7 @@ func TestContainerInstanceGetFunc(t *testing.T) { func TestNewECSContainerInstanceAdapter(t *testing.T) { client, account, region := ecsGetAutoConfig(t) - adapter := NewECSContainerInstanceAdapter(client, account, region, nil) + adapter := NewECSContainerInstanceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ecs-service_test.go b/aws-source/adapters/ecs-service_test.go index 8089c705..41b6789d 100644 --- a/aws-source/adapters/ecs-service_test.go +++ b/aws-source/adapters/ecs-service_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *ecsTestClient) DescribeServices(ctx context.Context, params *ecs.DescribeServicesInput, optFns ...func(*ecs.Options)) (*ecs.DescribeServicesOutput, error) { @@ -311,7 +312,7 @@ func TestServiceGetFunc(t *testing.T) { func TestNewECSServiceAdapter(t *testing.T) { client, account, region := ecsGetAutoConfig(t) - adapter := NewECSServiceAdapter(client, account, region, nil) + adapter := NewECSServiceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ecs-task-definition_test.go b/aws-source/adapters/ecs-task-definition_test.go index 44102220..7f65a16d 100644 --- a/aws-source/adapters/ecs-task-definition_test.go +++ b/aws-source/adapters/ecs-task-definition_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *ecsTestClient) DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) { @@ -255,7 +256,7 @@ func TestTaskDefinitionGetFunc(t *testing.T) { func TestNewECSTaskDefinitionAdapter(t *testing.T) { client, account, region := 
ecsGetAutoConfig(t) - adapter := NewECSTaskDefinitionAdapter(client, account, region, nil) + adapter := NewECSTaskDefinitionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ecs-task_test.go b/aws-source/adapters/ecs-task_test.go index 27b342c7..0f460c8b 100644 --- a/aws-source/adapters/ecs-task_test.go +++ b/aws-source/adapters/ecs-task_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *ecsTestClient) DescribeTasks(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { @@ -232,7 +233,7 @@ func TestTaskGetFunc(t *testing.T) { func TestNewECSTaskAdapter(t *testing.T) { client, account, region := ecsGetAutoConfig(t) - adapter := NewECSTaskAdapter(client, account, region, nil) + adapter := NewECSTaskAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ecs_test.go b/aws-source/adapters/ecs_test.go index 6bb72d76..e8471531 100644 --- a/aws-source/adapters/ecs_test.go +++ b/aws-source/adapters/ecs_test.go @@ -1,8 +1,9 @@ package adapters import ( - "github.com/aws/aws-sdk-go-v2/service/ecs" "testing" + + "github.com/aws/aws-sdk-go-v2/service/ecs" ) type ecsTestClient struct{} diff --git a/aws-source/adapters/efs-access-point_test.go b/aws-source/adapters/efs-access-point_test.go index 5e505951..9327b0d0 100644 --- a/aws-source/adapters/efs-access-point_test.go +++ b/aws-source/adapters/efs-access-point_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/efs" "github.com/aws/aws-sdk-go-v2/service/efs/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestAccessPointOutputMapper(t *testing.T) { @@ -82,7 +83,7 @@ func TestAccessPointOutputMapper(t *testing.T) { func TestNewEFSAccessPointAdapter(t *testing.T) { client, account, region := efsGetAutoConfig(t) - adapter := NewEFSAccessPointAdapter(client, account, region, nil) + adapter := NewEFSAccessPointAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/efs-file-system_test.go b/aws-source/adapters/efs-file-system_test.go index b927dc25..b0528b98 100644 --- a/aws-source/adapters/efs-file-system_test.go +++ b/aws-source/adapters/efs-file-system_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/efs" "github.com/aws/aws-sdk-go-v2/service/efs/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestFileSystemOutputMapper(t *testing.T) { @@ -93,7 +94,7 @@ func TestFileSystemOutputMapper(t *testing.T) { func TestNewEFSFileSystemAdapter(t *testing.T) { client, account, region := efsGetAutoConfig(t) - adapter := NewEFSFileSystemAdapter(client, account, region, nil) + adapter := NewEFSFileSystemAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/eks-addon_test.go b/aws-source/adapters/eks-addon_test.go index e01eabeb..eeb544ea 100644 --- a/aws-source/adapters/eks-addon_test.go +++ b/aws-source/adapters/eks-addon_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/eks" "github.com/aws/aws-sdk-go-v2/service/eks/types" + "github.com/overmindtech/cli/sdpcache" ) var AddonTestClient = EKSTestClient{ @@ -49,7 
+50,7 @@ func TestAddonGetFunc(t *testing.T) { func TestNewEKSAddonAdapter(t *testing.T) { client, account, region := eksGetAutoConfig(t) - adapter := NewEKSAddonAdapter(client, account, region, nil) + adapter := NewEKSAddonAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/eks-cluster_test.go b/aws-source/adapters/eks-cluster_test.go index 1a3eefcc..95aa92a0 100644 --- a/aws-source/adapters/eks-cluster_test.go +++ b/aws-source/adapters/eks-cluster_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/eks" "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) var ClusterClient = EKSTestClient{ @@ -206,7 +207,7 @@ func TestClusterGetFunc(t *testing.T) { func TestNewEKSClusterAdapter(t *testing.T) { client, account, region := eksGetAutoConfig(t) - adapter := NewEKSClusterAdapter(client, account, region, nil) + adapter := NewEKSClusterAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/eks-fargate-profile_test.go b/aws-source/adapters/eks-fargate-profile_test.go index ea718bda..ab107381 100644 --- a/aws-source/adapters/eks-fargate-profile_test.go +++ b/aws-source/adapters/eks-fargate-profile_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/eks" "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) var FargateTestClient = EKSTestClient{ @@ -67,7 +68,7 @@ func TestFargateProfileGetFunc(t *testing.T) { func TestNewEKSFargateProfileAdapter(t *testing.T) { client, account, region := eksGetAutoConfig(t) - adapter := NewEKSFargateProfileAdapter(client, account, region, nil) + adapter := NewEKSFargateProfileAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/eks-nodegroup_test.go b/aws-source/adapters/eks-nodegroup_test.go index 51ddc59b..2a64ba1d 100644 --- a/aws-source/adapters/eks-nodegroup_test.go +++ b/aws-source/adapters/eks-nodegroup_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/eks" "github.com/aws/aws-sdk-go-v2/service/eks/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) var NodeGroupClient = EKSTestClient{ @@ -132,7 +133,7 @@ func TestNodegroupGetFunc(t *testing.T) { func TestNewEKSNodegroupAdapter(t *testing.T) { client, account, region := eksGetAutoConfig(t) - adapter := NewEKSNodegroupAdapter(client, account, region, nil) + adapter := NewEKSNodegroupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/elbv2-listener.go b/aws-source/adapters/elbv2-listener.go index 0c0f8ff4..6047d67c 100644 --- a/aws-source/adapters/elbv2-listener.go +++ b/aws-source/adapters/elbv2-listener.go @@ -140,7 +140,7 @@ func NewELBv2ListenerAdapter(client elbv2Client, accountID string, region string AccountID: accountID, ItemType: "elbv2-listener", AdapterMetadata: elbv2ListenerAdapterMetadata, - cache: cache, + cache: cache, DescribeFunc: func(ctx context.Context, client elbv2Client, input *elbv2.DescribeListenersInput) (*elbv2.DescribeListenersOutput, error) { return client.DescribeListeners(ctx, input) }, diff --git a/aws-source/adapters/elbv2-load-balancer.go b/aws-source/adapters/elbv2-load-balancer.go index 615c75e6..59967eb2 100644 --- 
a/aws-source/adapters/elbv2-load-balancer.go +++ b/aws-source/adapters/elbv2-load-balancer.go @@ -257,7 +257,7 @@ func NewELBv2LoadBalancerAdapter(client elbv2Client, accountID string, region st AccountID: accountID, ItemType: "elbv2-load-balancer", AdapterMetadata: loadBalancerAdapterMetadata, - cache: cache, + cache: cache, DescribeFunc: func(ctx context.Context, client elbv2Client, input *elbv2.DescribeLoadBalancersInput) (*elbv2.DescribeLoadBalancersOutput, error) { return client.DescribeLoadBalancers(ctx, input) }, diff --git a/aws-source/adapters/elbv2-rule.go b/aws-source/adapters/elbv2-rule.go index 292d5caa..99aa3795 100644 --- a/aws-source/adapters/elbv2-rule.go +++ b/aws-source/adapters/elbv2-rule.go @@ -83,7 +83,7 @@ func NewELBv2RuleAdapter(client elbv2Client, accountID string, region string, ca AccountID: accountID, ItemType: "elbv2-rule", AdapterMetadata: ruleAdapterMetadata, - cache: cache, + cache: cache, DescribeFunc: func(ctx context.Context, client elbv2Client, input *elbv2.DescribeRulesInput) (*elbv2.DescribeRulesOutput, error) { return client.DescribeRules(ctx, input) }, diff --git a/aws-source/adapters/elbv2-rule_test.go b/aws-source/adapters/elbv2-rule_test.go index 93f21c2a..6117bb2c 100644 --- a/aws-source/adapters/elbv2-rule_test.go +++ b/aws-source/adapters/elbv2-rule_test.go @@ -11,6 +11,7 @@ import ( "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestRuleOutputMapper(t *testing.T) { @@ -96,9 +97,9 @@ func TestNewELBv2RuleAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := elbv2.NewFromConfig(config) - lbSource := NewELBv2LoadBalancerAdapter(client, account, region, nil) - listenerSource := NewELBv2ListenerAdapter(client, account, region, nil) - ruleSource := NewELBv2RuleAdapter(client, account, region, nil) + lbSource := NewELBv2LoadBalancerAdapter(client, account, region, sdpcache.NewNoOpCache()) + listenerSource := NewELBv2ListenerAdapter(client, account, region, sdpcache.NewNoOpCache()) + ruleSource := NewELBv2RuleAdapter(client, account, region, sdpcache.NewNoOpCache()) stream := discovery.NewRecordingQueryResultStream() lbSource.ListStream(context.Background(), lbSource.Scopes()[0], false, stream) diff --git a/aws-source/adapters/elbv2-target-group_test.go b/aws-source/adapters/elbv2-target-group_test.go index 60eb1c98..88a48ab7 100644 --- a/aws-source/adapters/elbv2-target-group_test.go +++ b/aws-source/adapters/elbv2-target-group_test.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestTargetGroupOutputMapper(t *testing.T) { @@ -88,7 +89,7 @@ func TestNewELBv2TargetGroupAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := elbv2.NewFromConfig(config) - adapter := NewELBv2TargetGroupAdapter(client, account, region, nil) + adapter := NewELBv2TargetGroupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/iam-group_test.go b/aws-source/adapters/iam-group_test.go index 842eab82..ec6651fe 100644 --- a/aws-source/adapters/iam-group_test.go +++ b/aws-source/adapters/iam-group_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/overmindtech/cli/sdpcache" ) func TestGroupItemMapper(t *testing.T) { @@ -37,7 +38,7 @@ 
func TestNewIAMGroupAdapter(t *testing.T) { o.RetryMaxAttempts = 10 }) - adapter := NewIAMGroupAdapter(client, account, nil) + adapter := NewIAMGroupAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/iam-instance-profile_test.go b/aws-source/adapters/iam-instance-profile_test.go index e08a91ae..092e9d58 100644 --- a/aws-source/adapters/iam-instance-profile_test.go +++ b/aws-source/adapters/iam-instance-profile_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/iam" "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/overmindtech/cli/sdpcache" ) func TestInstanceProfileItemMapper(t *testing.T) { @@ -57,7 +58,7 @@ func TestNewIAMInstanceProfileAdapter(t *testing.T) { o.RetryMaxAttempts = 10 }) - adapter := NewIAMInstanceProfileAdapter(client, account, nil) + adapter := NewIAMInstanceProfileAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/iam-policy.go b/aws-source/adapters/iam-policy.go index 9df687e0..009cb900 100644 --- a/aws-source/adapters/iam-policy.go +++ b/aws-source/adapters/iam-policy.go @@ -279,7 +279,7 @@ func NewIAMPolicyAdapter(client IAMClient, accountID string, cache sdpcache.Cach Region: "", // IAM policies aren't tied to a region CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time AdapterMetadata: policyAdapterMetadata, - cache: cache, + cache: cache, SupportGlobalResources: true, InputMapperList: func(scope string) (*iam.ListPoliciesInput, error) { var iamScope types.PolicyScopeType diff --git a/aws-source/adapters/iam-policy_test.go b/aws-source/adapters/iam-policy_test.go index 5afc260b..cbdf09cc 100644 --- a/aws-source/adapters/iam-policy_test.go +++ b/aws-source/adapters/iam-policy_test.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *TestIAMClient) GetPolicy(ctx context.Context, params *iam.GetPolicyInput, optFns ...func(*iam.Options)) (*iam.GetPolicyOutput, error) { @@ -312,7 +313,7 @@ func TestNewIAMPolicyAdapter(t *testing.T) { o.RetryMaxAttempts = 10 }) - adapter := NewIAMPolicyAdapter(client, account, nil) + adapter := NewIAMPolicyAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/iam-role.go b/aws-source/adapters/iam-role.go index 9d19c6c8..58d735ac 100644 --- a/aws-source/adapters/iam-role.go +++ b/aws-source/adapters/iam-role.go @@ -254,6 +254,7 @@ func NewIAMRoleAdapter(client IAMClient, accountID string, cache sdpcache.Cache) ItemType: "iam-role", Client: client, CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time + cache: cache, AccountID: accountID, GetFunc: func(ctx context.Context, client IAMClient, scope, query string) (*RoleDetails, error) { return roleGetFunc(ctx, client, scope, query) diff --git a/aws-source/adapters/iam-role_test.go b/aws-source/adapters/iam-role_test.go index a7557308..8b581863 100644 --- a/aws-source/adapters/iam-role_test.go +++ b/aws-source/adapters/iam-role_test.go @@ -13,6 +13,7 @@ import ( "github.com/micahhausler/aws-iam-policy/policy" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *TestIAMClient) GetRole(ctx context.Context, params 
*iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) { @@ -140,7 +141,7 @@ func TestRoleGetFunc(t *testing.T) { } func TestRoleListFunc(t *testing.T) { - adapter := NewIAMRoleAdapter(&TestIAMClient{}, "foo", nil) + adapter := NewIAMRoleAdapter(&TestIAMClient{}, "foo", sdpcache.NewNoOpCache()) stream := discovery.NewRecordingQueryResultStream() adapter.ListStream(context.Background(), "foo", false, stream) @@ -253,7 +254,7 @@ func TestNewIAMRoleAdapter(t *testing.T) { o.RetryMaxAttempts = 10 }) - adapter := NewIAMRoleAdapter(client, account, nil) + adapter := NewIAMRoleAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/iam-user.go b/aws-source/adapters/iam-user.go index f21658e3..c96801f2 100644 --- a/aws-source/adapters/iam-user.go +++ b/aws-source/adapters/iam-user.go @@ -139,6 +139,7 @@ func NewIAMUserAdapter(client IAMClient, accountID string, cache sdpcache.Cache) return &GetListAdapterV2[*iam.ListUsersInput, *iam.ListUsersOutput, *UserDetails, IAMClient, *iam.Options]{ ItemType: "iam-user", Client: client, + cache: cache, CacheDuration: 3 * time.Hour, // IAM has very low rate limits, we need to cache for a long time AccountID: accountID, GetFunc: func(ctx context.Context, client IAMClient, scope, query string) (*UserDetails, error) { diff --git a/aws-source/adapters/iam-user_test.go b/aws-source/adapters/iam-user_test.go index c9ccde4e..3d3bbe71 100644 --- a/aws-source/adapters/iam-user_test.go +++ b/aws-source/adapters/iam-user_test.go @@ -13,6 +13,7 @@ import ( "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func (t *TestIAMClient) ListGroupsForUser(ctx context.Context, params *iam.ListGroupsForUserInput, optFns ...func(*iam.Options)) (*iam.ListGroupsForUserOutput, error) { @@ -139,7 +140,7 @@ func TestUserGetFunc(t *testing.T) { } func TestUserListFunc(t *testing.T) { - adapter := NewIAMUserAdapter(&TestIAMClient{}, "foo", nil) + adapter := NewIAMUserAdapter(&TestIAMClient{}, "foo", sdpcache.NewNoOpCache()) stream := discovery.NewRecordingQueryResultStream() adapter.ListStream(context.Background(), "foo", false, stream) @@ -227,7 +228,7 @@ func TestNewIAMUserAdapter(t *testing.T) { o.RetryMaxAttempts = 10 }) - adapter := NewIAMUserAdapter(client, account, nil) + adapter := NewIAMUserAdapter(client, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/integration/apigateway/apigateway_test.go b/aws-source/adapters/integration/apigateway/apigateway_test.go index 01e455b0..25211873 100644 --- a/aws-source/adapters/integration/apigateway/apigateway_test.go +++ b/aws-source/adapters/integration/apigateway/apigateway_test.go @@ -8,6 +8,7 @@ import ( "github.com/overmindtech/cli/aws-source/adapters" "github.com/overmindtech/cli/aws-source/adapters/integration" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func APIGateway(t *testing.T) { @@ -30,70 +31,70 @@ func APIGateway(t *testing.T) { // Resources ------------------------------------------------------------------------------------------------------ - restApiSource := adapters.NewAPIGatewayRestApiAdapter(testClient, accountID, testAWSConfig.Region, nil) + restApiSource := adapters.NewAPIGatewayRestApiAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = restApiSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway restApi adapter: %v", err) } - 
resourceApiSource := adapters.NewAPIGatewayResourceAdapter(testClient, accountID, testAWSConfig.Region, nil) + resourceApiSource := adapters.NewAPIGatewayResourceAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = resourceApiSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway resource adapter: %v", err) } - methodSource := adapters.NewAPIGatewayMethodAdapter(testClient, accountID, testAWSConfig.Region, nil) + methodSource := adapters.NewAPIGatewayMethodAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = methodSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway method adapter: %v", err) } - methodResponseSource := adapters.NewAPIGatewayMethodResponseAdapter(testClient, accountID, testAWSConfig.Region, nil) + methodResponseSource := adapters.NewAPIGatewayMethodResponseAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = methodResponseSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway method response adapter: %v", err) } - integrationSource := adapters.NewAPIGatewayIntegrationAdapter(testClient, accountID, testAWSConfig.Region, nil) + integrationSource := adapters.NewAPIGatewayIntegrationAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = integrationSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway integration adapter: %v", err) } - apiKeySource := adapters.NewAPIGatewayApiKeyAdapter(testClient, accountID, testAWSConfig.Region, nil) + apiKeySource := adapters.NewAPIGatewayApiKeyAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = apiKeySource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway API key adapter: %v", err) } - authorizerSource := adapters.NewAPIGatewayAuthorizerAdapter(testClient, accountID, testAWSConfig.Region, nil) + authorizerSource := adapters.NewAPIGatewayAuthorizerAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = authorizerSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway authorizer adapter: %v", err) } - deploymentSource := adapters.NewAPIGatewayDeploymentAdapter(testClient, accountID, testAWSConfig.Region, nil) + deploymentSource := adapters.NewAPIGatewayDeploymentAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = deploymentSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway deployment adapter: %v", err) } - stageSource := adapters.NewAPIGatewayStageAdapter(testClient, accountID, testAWSConfig.Region, nil) + stageSource := adapters.NewAPIGatewayStageAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = stageSource.Validate() if err != nil { t.Fatalf("failed to validate APIGateway stage adapter: %v", err) } - modelSource := adapters.NewAPIGatewayModelAdapter(testClient, accountID, testAWSConfig.Region, nil) + modelSource := adapters.NewAPIGatewayModelAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = modelSource.Validate() if err != nil { diff --git a/aws-source/adapters/integration/ec2/instance_test.go b/aws-source/adapters/integration/ec2/instance_test.go index 6aa87d87..730339ef 100644 --- a/aws-source/adapters/integration/ec2/instance_test.go +++ b/aws-source/adapters/integration/ec2/instance_test.go @@ -9,6 +9,7 @@ import ( "github.com/overmindtech/cli/aws-source/adapters/integration" "github.com/overmindtech/cli/discovery" 
"github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func searchSync(adapter discovery.SearchStreamableAdapter, ctx context.Context, scope, query string, ignoreCache bool) ([]*sdp.Item, error) { @@ -53,7 +54,7 @@ func EC2(t *testing.T) { t.Log("Running EC2 integration test") - instanceAdapter := adapters.NewEC2InstanceAdapter(testClient, accountID, testAWSConfig.Region, nil) + instanceAdapter := adapters.NewEC2InstanceAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = instanceAdapter.Validate() if err != nil { diff --git a/aws-source/adapters/integration/kms/kms_test.go b/aws-source/adapters/integration/kms/kms_test.go index 3597f73a..3733d15c 100644 --- a/aws-source/adapters/integration/kms/kms_test.go +++ b/aws-source/adapters/integration/kms/kms_test.go @@ -10,6 +10,7 @@ import ( "github.com/overmindtech/cli/aws-source/adapters/integration" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func searchSync(adapter discovery.SearchStreamableAdapter, ctx context.Context, scope, query string, ignoreCache bool) ([]*sdp.Item, error) { @@ -54,13 +55,13 @@ func KMS(t *testing.T) { t.Log("Running KMS integration test") - keySource := adapters.NewKMSKeyAdapter(testClient, accountID, testAWSConfig.Region, nil) + keySource := adapters.NewKMSKeyAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) - aliasSource := adapters.NewKMSAliasAdapter(testClient, accountID, testAWSConfig.Region, nil) + aliasSource := adapters.NewKMSAliasAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) - grantSource := adapters.NewKMSGrantAdapter(testClient, accountID, testAWSConfig.Region, nil) + grantSource := adapters.NewKMSGrantAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) - keyPolicySource := adapters.NewKMSKeyPolicyAdapter(testClient, accountID, testAWSConfig.Region, nil) + keyPolicySource := adapters.NewKMSKeyPolicyAdapter(testClient, accountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) err = keySource.Validate() if err != nil { diff --git a/aws-source/adapters/integration/networkmanager/networkmanager_test.go b/aws-source/adapters/integration/networkmanager/networkmanager_test.go index 25babc78..7a2f490f 100644 --- a/aws-source/adapters/integration/networkmanager/networkmanager_test.go +++ b/aws-source/adapters/integration/networkmanager/networkmanager_test.go @@ -10,6 +10,7 @@ import ( "github.com/overmindtech/cli/aws-source/adapters/integration" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func searchSync(adapter discovery.SearchStreamableAdapter, ctx context.Context, scope, query string, ignoreCache bool) ([]*sdp.Item, error) { @@ -42,32 +43,32 @@ func NetworkManager(t *testing.T) { t.Logf("Running NetworkManager integration tests") - globalNetworkSource := adapters.NewNetworkManagerGlobalNetworkAdapter(testClient, accountID, nil) + globalNetworkSource := adapters.NewNetworkManagerGlobalNetworkAdapter(testClient, accountID, sdpcache.NewNoOpCache()) if err := globalNetworkSource.Validate(); err != nil { t.Fatalf("failed to validate NetworkManager global network adapter: %v", err) } - siteSource := adapters.NewNetworkManagerSiteAdapter(testClient, accountID, nil) + siteSource := adapters.NewNetworkManagerSiteAdapter(testClient, accountID, sdpcache.NewNoOpCache()) if err := siteSource.Validate(); err != nil { t.Fatalf("failed to 
validate NetworkManager site adapter: %v", err) } - linkSource := adapters.NewNetworkManagerLinkAdapter(testClient, accountID, nil) + linkSource := adapters.NewNetworkManagerLinkAdapter(testClient, accountID, sdpcache.NewNoOpCache()) if err := linkSource.Validate(); err != nil { t.Fatalf("failed to validate NetworkManager link adapter: %v", err) } - linkAssociationSource := adapters.NewNetworkManagerLinkAssociationAdapter(testClient, accountID, nil) + linkAssociationSource := adapters.NewNetworkManagerLinkAssociationAdapter(testClient, accountID, sdpcache.NewNoOpCache()) if err := linkAssociationSource.Validate(); err != nil { t.Fatalf("failed to validate NetworkManager link association adapter: %v", err) } - connectionSource := adapters.NewNetworkManagerConnectionAdapter(testClient, accountID, nil) + connectionSource := adapters.NewNetworkManagerConnectionAdapter(testClient, accountID, sdpcache.NewNoOpCache()) if err := connectionSource.Validate(); err != nil { t.Fatalf("failed to validate NetworkManager connection adapter: %v", err) } - deviceSource := adapters.NewNetworkManagerDeviceAdapter(testClient, accountID, nil) + deviceSource := adapters.NewNetworkManagerDeviceAdapter(testClient, accountID, sdpcache.NewNoOpCache()) if err := deviceSource.Validate(); err != nil { t.Fatalf("failed to validate NetworkManager device adapter: %v", err) } diff --git a/aws-source/adapters/integration/ssm/main_test.go b/aws-source/adapters/integration/ssm/main_test.go index b9e27f3f..db6b64ab 100644 --- a/aws-source/adapters/integration/ssm/main_test.go +++ b/aws-source/adapters/integration/ssm/main_test.go @@ -15,6 +15,7 @@ import ( "github.com/overmindtech/cli/aws-source/adapters" "github.com/overmindtech/cli/aws-source/adapters/integration" "github.com/overmindtech/cli/discovery" + "github.com/overmindtech/cli/sdpcache" "github.com/overmindtech/cli/tracing" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -120,7 +121,7 @@ func TestIntegrationSSM(t *testing.T) { client := ssm.NewFromConfig(testAWSConfig.Config) scope := testAWSConfig.AccountID + "." 
+ testAWSConfig.Region - adapter := adapters.NewSSMParameterAdapter(client, testAWSConfig.AccountID, testAWSConfig.Region, nil) + adapter := adapters.NewSSMParameterAdapter(client, testAWSConfig.AccountID, testAWSConfig.Region, sdpcache.NewNoOpCache()) ctx, span := tracer.Start(ctx, "SSM.List") defer span.End() diff --git a/aws-source/adapters/kms-alias_test.go b/aws-source/adapters/kms-alias_test.go index 86fbfd1b..4b50aa6c 100644 --- a/aws-source/adapters/kms-alias_test.go +++ b/aws-source/adapters/kms-alias_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestAliasOutputMapper(t *testing.T) { @@ -56,7 +57,7 @@ func TestNewKMSAliasAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := kms.NewFromConfig(config) - adapter := NewKMSAliasAdapter(client, account, region, nil) + adapter := NewKMSAliasAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/kms-custom-key-store_test.go b/aws-source/adapters/kms-custom-key-store_test.go index 5e38c2f0..10a420b0 100644 --- a/aws-source/adapters/kms-custom-key-store_test.go +++ b/aws-source/adapters/kms-custom-key-store_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestCustomKeyStoreOutputMapper(t *testing.T) { @@ -59,7 +60,7 @@ func TestNewKMSCustomKeyStoreAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := kms.NewFromConfig(config) - adapter := NewKMSCustomKeyStoreAdapter(client, account, region, nil) + adapter := NewKMSCustomKeyStoreAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/kms-grant.go b/aws-source/adapters/kms-grant.go index 487d47fa..66b8bd49 100644 --- a/aws-source/adapters/kms-grant.go +++ b/aws-source/adapters/kms-grant.go @@ -97,15 +97,27 @@ func grantOutputMapper(ctx context.Context, _ *kms.Client, scope string, _ *kms. dynamodb.us-west-2.amazonaws.com - The following are not supported + The following are not supported (we skip them silently): - arn:aws:iam::account:root - arn:aws:sts::account:federated-user/user-name - arn:aws:sts::account:assumed-role/role-name/role-session-name - arn:aws:sts::account:self - - dynamodb.us-west-2.amazonaws.com => this will cause an error in ARN parsing + - Service principals like dynamodb.us-west-2.amazonaws.com (not ARNs, not linkable) */ for _, principal := range principals { + // Skip AWS service principals (e.g. "rds.eu-west-2.amazonaws.com", + // "dynamodb.us-west-2.amazonaws.com"). These are DNS-style identifiers + // for AWS services, not ARNs, and are not linkable to other items. + if isAWSServicePrincipal(principal) { + log.WithFields(log.Fields{ + "input": principal, + "scope": scope, + }).Debug("Skipping AWS service principal (not linkable)") + + continue + } + lIQ := &sdp.LinkedItemQuery{ Query: &sdp.Query{ Method: sdp.QueryMethod_GET, @@ -126,7 +138,7 @@ func grantOutputMapper(ctx context.Context, _ *kms.Client, scope string, _ *kms. 
"error": errA, "input": principal, "scope": scope, - }).Error("Error parsing principal ARN") + }).Warn("Error parsing principal ARN") continue } @@ -237,3 +249,24 @@ func iamSourceAndQuery(resource string) (string, string) { return adapter, query // user, user-name-with-path } + +// isAWSServicePrincipal returns true if the principal is an AWS service +// principal (e.g. "rds.eu-west-2.amazonaws.com", "dynamodb.us-west-2.amazonaws.com"). +// These are DNS-style identifiers used by AWS services to assume roles or access +// resources, and are not ARNs. +func isAWSServicePrincipal(principal string) bool { + // Service principals don't start with "arn:" and end with a partition-specific + // DNS suffix. + if strings.HasPrefix(principal, "arn:") { + return false + } + + // Check all AWS partition DNS suffixes using the shared list + for _, suffix := range GetAllAWSPartitionDNSSuffixes() { + if strings.HasSuffix(principal, "."+suffix) { + return true + } + } + + return false +} diff --git a/aws-source/adapters/kms-grant_test.go b/aws-source/adapters/kms-grant_test.go index 8ae2a653..3dd5a08a 100644 --- a/aws-source/adapters/kms-grant_test.go +++ b/aws-source/adapters/kms-grant_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) /* @@ -43,6 +44,82 @@ An example list grants response: } */ +func TestIsAWSServicePrincipal(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + principal string + expected bool + }{ + { + name: "RDS service principal", + principal: "rds.eu-west-2.amazonaws.com", + expected: true, + }, + { + name: "DynamoDB service principal", + principal: "dynamodb.us-west-2.amazonaws.com", + expected: true, + }, + { + name: "EC2 service principal", + principal: "ec2.amazonaws.com", + expected: true, + }, + { + name: "China region service principal (aws-cn)", + principal: "rds.cn-north-1.amazonaws.com.cn", + expected: true, + }, + { + name: "EU partition service principal (aws-eu)", + principal: "rds.eu-central-1.amazonaws.eu", + expected: true, + }, + { + name: "ISO partition service principal (aws-iso)", + principal: "rds.us-iso-east-1.c2s.ic.gov", + expected: true, + }, + { + name: "ISO-B partition service principal (aws-iso-b)", + principal: "rds.us-isob-east-1.sc2s.sgov.gov", + expected: true, + }, + { + name: "IAM role ARN", + principal: "arn:aws:iam::123456789012:role/MyRole", + expected: false, + }, + { + name: "IAM user ARN", + principal: "arn:aws:iam::123456789012:user/MyUser", + expected: false, + }, + { + name: "Account root ARN", + principal: "arn:aws:iam::123456789012:root", + expected: false, + }, + { + name: "Random string", + principal: "not-a-principal", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := isAWSServicePrincipal(tt.principal) + if result != tt.expected { + t.Errorf("isAWSServicePrincipal(%q) = %v, expected %v", tt.principal, result, tt.expected) + } + }) + } +} + func TestGrantOutputMapper(t *testing.T) { output := &kms.ListGrantsOutput{ Grants: []types.GrantListEntry{ @@ -108,11 +185,60 @@ func TestGrantOutputMapper(t *testing.T) { tests.Execute(t, item) } +func TestGrantOutputMapperWithServicePrincipal(t *testing.T) { + // Test that service principals (like dynamodb.us-west-2.amazonaws.com) are + // properly skipped and don't cause errors or generate linked item queries + output := &kms.ListGrantsOutput{ + Grants: 
[]types.GrantListEntry{ + { + Constraints: &types.GrantConstraints{ + EncryptionContextSubset: map[string]string{ + "aws:dynamodb:subscriberId": "123456789012", + "aws:dynamodb:tableName": "Services", + }, + }, + IssuingAccount: PtrString("arn:aws:iam::123456789012:root"), + Name: PtrString("8276b9a6-6cf0-46f1-b2f0-7993a7f8c89a"), + Operations: []types.GrantOperation{"Decrypt", "Encrypt"}, + GrantId: PtrString("1667b97d27cf748cf05b487217dd4179526c949d14fb3903858e25193253fe59"), + KeyId: PtrString("arn:aws:kms:us-west-2:123456789012:key/1234abcd-12ab-34cd-56ef-1234567890ab"), + // These are service principals, not ARNs - they should be skipped + RetiringPrincipal: PtrString("dynamodb.us-west-2.amazonaws.com"), + GranteePrincipal: PtrString("rds.eu-west-2.amazonaws.com"), + CreationDate: PtrTime(time.Now()), + }, + }, + } + + items, err := grantOutputMapper(context.Background(), nil, "foo", nil, output) + if err != nil { + t.Fatal(err) + } + + if len(items) != 1 { + t.Fatalf("expected 1 item, got %v", len(items)) + } + + item := items[0] + + // Should only have the kms-key link, not the service principals + if len(item.GetLinkedItemQueries()) != 1 { + t.Errorf("expected 1 linked item query (kms-key only), got %v", len(item.GetLinkedItemQueries())) + for i, liq := range item.GetLinkedItemQueries() { + t.Logf(" [%d] type=%s query=%s", i, liq.GetQuery().GetType(), liq.GetQuery().GetQuery()) + } + } + + if item.GetLinkedItemQueries()[0].GetQuery().GetType() != "kms-key" { + t.Errorf("expected linked item query to be kms-key, got %s", item.GetLinkedItemQueries()[0].GetQuery().GetType()) + } +} + func TestNewKMSGrantAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := kms.NewFromConfig(config) - adapter := NewKMSGrantAdapter(client, account, region, nil) + adapter := NewKMSGrantAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/kms-key-policy_test.go b/aws-source/adapters/kms-key-policy_test.go index c12a143e..96682d43 100644 --- a/aws-source/adapters/kms-key-policy_test.go +++ b/aws-source/adapters/kms-key-policy_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) /* @@ -108,7 +109,7 @@ func TestNewKMSKeyPolicyAdapter(t *testing.T) { client := kms.NewFromConfig(config) - adapter := NewKMSKeyPolicyAdapter(client, account, region, nil) + adapter := NewKMSKeyPolicyAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/kms-key_test.go b/aws-source/adapters/kms-key_test.go index 1462017e..5d88422c 100644 --- a/aws-source/adapters/kms-key_test.go +++ b/aws-source/adapters/kms-key_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/overmindtech/cli/sdpcache" ) type kmsTestClient struct{} @@ -87,11 +88,11 @@ func TestKMSGetFunc(t *testing.T) { } func TestNewKMSKeyAdapter(t *testing.T) { - t.Skip("This test is currently failing due to a key that none of us can read, even with admin permissions. I think we will need to speak with AWS support to work out how to delete it", nil) + t.Skip("This test is currently failing due to a key that none of us can read, even with admin permissions. 
I think we will need to speak with AWS support to work out how to delete it") config, account, region := GetAutoConfig(t) client := kms.NewFromConfig(config) - adapter := NewKMSKeyAdapter(client, account, region, nil) + adapter := NewKMSKeyAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/lambda-event-source-mapping.go b/aws-source/adapters/lambda-event-source-mapping.go index edbdfed6..5a53be8a 100644 --- a/aws-source/adapters/lambda-event-source-mapping.go +++ b/aws-source/adapters/lambda-event-source-mapping.go @@ -184,7 +184,7 @@ func NewLambdaEventSourceMappingAdapter(client lambdaEventSourceMappingClient, a AccountID: accountID, Region: region, AdapterMetadata: lambdaEventSourceMappingAdapterMetadata, - cache: cache, + cache: cache, GetFunc: func(ctx context.Context, client lambdaEventSourceMappingClient, scope, query string) (*types.EventSourceMappingConfiguration, error) { out, err := client.GetEventSourceMapping(ctx, &lambda.GetEventSourceMappingInput{ UUID: &query, diff --git a/aws-source/adapters/lambda-event-source-mapping_test.go b/aws-source/adapters/lambda-event-source-mapping_test.go index c03895bd..e671c50d 100644 --- a/aws-source/adapters/lambda-event-source-mapping_test.go +++ b/aws-source/adapters/lambda-event-source-mapping_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda" "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) type TestLambdaEventSourceMappingClient struct{} @@ -88,7 +89,7 @@ func stringPtr(s string) *string { } func TestLambdaEventSourceMappingAdapter(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test adapter metadata if adapter.Type() != "lambda-event-source-mapping" { @@ -110,7 +111,7 @@ func TestLambdaEventSourceMappingAdapter(t *testing.T) { } func TestLambdaEventSourceMappingGetFunc(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test getting existing event source mapping item, err := adapter.Get(context.Background(), "123456789012.us-east-1", "test-uuid-1", false) @@ -145,7 +146,7 @@ func TestLambdaEventSourceMappingGetFunc(t *testing.T) { } func TestLambdaEventSourceMappingItemMapper(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test mapping with SQS event source awsItem := &types.EventSourceMappingConfiguration{ @@ -208,7 +209,7 @@ func TestLambdaEventSourceMappingItemMapper(t *testing.T) { } func TestLambdaEventSourceMappingItemMapperWithDynamoDB(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test mapping with DynamoDB event source awsItem := 
&types.EventSourceMappingConfiguration{ @@ -241,7 +242,7 @@ func TestLambdaEventSourceMappingItemMapperWithDynamoDB(t *testing.T) { } func TestLambdaEventSourceMappingItemMapperWithRDS(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test mapping with RDS/DocumentDB event source awsItem := &types.EventSourceMappingConfiguration{ @@ -274,7 +275,7 @@ func TestLambdaEventSourceMappingItemMapperWithRDS(t *testing.T) { } func TestLambdaEventSourceMappingSearchByEventSourceARN(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test search by SQS queue ARN sqsQueueARN := "arn:aws:sqs:us-east-1:123456789012:test-queue" @@ -294,7 +295,7 @@ func TestLambdaEventSourceMappingSearchByEventSourceARN(t *testing.T) { } func TestLambdaEventSourceMappingSearchWrongScope(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test search with wrong scope _, err := adapter.Search(context.Background(), "wrong-scope", "arn:aws:sqs:us-east-1:123456789012:test-queue", false) @@ -304,7 +305,7 @@ func TestLambdaEventSourceMappingSearchWrongScope(t *testing.T) { } func TestLambdaEventSourceMappingAdapterList(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test List items, err := adapter.List(context.Background(), "123456789012.us-east-1", false) @@ -338,7 +339,7 @@ func TestLambdaEventSourceMappingAdapterList(t *testing.T) { } func TestLambdaEventSourceMappingAdapterListWrongScope(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test List with wrong scope _, err := adapter.List(context.Background(), "wrong-scope", false) @@ -348,7 +349,7 @@ func TestLambdaEventSourceMappingAdapterListWrongScope(t *testing.T) { } func TestLambdaEventSourceMappingAdapterIntegration(t *testing.T) { - adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", nil) + adapter := NewLambdaEventSourceMappingAdapter(&TestLambdaEventSourceMappingClient{}, "123456789012", "us-east-1", sdpcache.NewNoOpCache()) // Test Get item, err := adapter.Get(context.Background(), "123456789012.us-east-1", "test-uuid-1", false) diff --git a/aws-source/adapters/lambda-function_test.go b/aws-source/adapters/lambda-function_test.go index a2ce6a5b..f6613dfe 100644 --- a/aws-source/adapters/lambda-function_test.go +++ b/aws-source/adapters/lambda-function_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda" 
"github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) var testFuncConfig = &types.FunctionConfiguration{ @@ -375,7 +376,7 @@ func TestGetEventLinkedItem(t *testing.T) { func TestNewLambdaFunctionAdapter(t *testing.T) { client, account, region := lambdaGetAutoConfig(t) - adapter := NewLambdaFunctionAdapter(client, account, region, nil) + adapter := NewLambdaFunctionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/lambda-layer-version_test.go b/aws-source/adapters/lambda-layer-version_test.go index cd3537db..09c06411 100644 --- a/aws-source/adapters/lambda-layer-version_test.go +++ b/aws-source/adapters/lambda-layer-version_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda" "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestLayerVersionGetInputMapper(t *testing.T) { @@ -115,7 +116,7 @@ func TestLayerVersionGetFunc(t *testing.T) { func TestNewLambdaLayerVersionAdapter(t *testing.T) { client, account, region := lambdaGetAutoConfig(t) - adapter := NewLambdaLayerVersionAdapter(client, account, region, nil) + adapter := NewLambdaLayerVersionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/lambda-layer_test.go b/aws-source/adapters/lambda-layer_test.go index 9b409353..71713317 100644 --- a/aws-source/adapters/lambda-layer_test.go +++ b/aws-source/adapters/lambda-layer_test.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/lambda/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestLayerItemMapper(t *testing.T) { @@ -53,7 +54,7 @@ func TestLayerItemMapper(t *testing.T) { func TestNewLambdaLayerAdapter(t *testing.T) { client, account, region := lambdaGetAutoConfig(t) - adapter := NewLambdaLayerAdapter(client, account, region, nil) + adapter := NewLambdaLayerAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/networkmanager-connection_test.go b/aws-source/adapters/networkmanager-connection_test.go index 2f4e392e..c873431c 100644 --- a/aws-source/adapters/networkmanager-connection_test.go +++ b/aws-source/adapters/networkmanager-connection_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/networkmanager/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestConnectionOutputMapper(t *testing.T) { @@ -90,7 +91,7 @@ func TestConnectionOutputMapper(t *testing.T) { } func TestConnectionInputMapperSearch(t *testing.T) { - adapter := NewNetworkManagerConnectionAdapter(&networkmanager.Client{}, "123456789012", nil) + adapter := NewNetworkManagerConnectionAdapter(&networkmanager.Client{}, "123456789012", sdpcache.NewNoOpCache()) tests := []struct { name string diff --git a/aws-source/adapters/networkmanager-device_test.go b/aws-source/adapters/networkmanager-device_test.go index 64dafdc0..76f432e1 100644 --- a/aws-source/adapters/networkmanager-device_test.go +++ b/aws-source/adapters/networkmanager-device_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/networkmanager/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDeviceOutputMapper(t *testing.T) { @@ -88,7 +89,7 @@ func TestDeviceOutputMapper(t 
*testing.T) { } func TestDeviceInputMapperSearch(t *testing.T) { - adapter := NewNetworkManagerDeviceAdapter(&networkmanager.Client{}, "123456789012", nil) + adapter := NewNetworkManagerDeviceAdapter(&networkmanager.Client{}, "123456789012", sdpcache.NewNoOpCache()) tests := []struct { name string diff --git a/aws-source/adapters/networkmanager-link_test.go b/aws-source/adapters/networkmanager-link_test.go index 770c62ef..4f3c3e60 100644 --- a/aws-source/adapters/networkmanager-link_test.go +++ b/aws-source/adapters/networkmanager-link_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/networkmanager/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestLinkOutputMapper(t *testing.T) { @@ -82,7 +83,7 @@ func TestLinkOutputMapper(t *testing.T) { } func TestLinkInputMapperSearch(t *testing.T) { - adapter := NewNetworkManagerLinkAdapter(&networkmanager.Client{}, "123456789012", nil) + adapter := NewNetworkManagerLinkAdapter(&networkmanager.Client{}, "123456789012", sdpcache.NewNoOpCache()) tests := []struct { name string diff --git a/aws-source/adapters/networkmanager-site_test.go b/aws-source/adapters/networkmanager-site_test.go index ef8b5cde..94299f3f 100644 --- a/aws-source/adapters/networkmanager-site_test.go +++ b/aws-source/adapters/networkmanager-site_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/networkmanager/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestSiteOutputMapper(t *testing.T) { @@ -74,7 +75,7 @@ func TestSiteOutputMapper(t *testing.T) { } func TestSiteInputMapperSearch(t *testing.T) { - adapter := NewNetworkManagerSiteAdapter(&networkmanager.Client{}, "123456789012", nil) + adapter := NewNetworkManagerSiteAdapter(&networkmanager.Client{}, "123456789012", sdpcache.NewNoOpCache()) tests := []struct { name string diff --git a/aws-source/adapters/rds-db-cluster-parameter-group_test.go b/aws-source/adapters/rds-db-cluster-parameter-group_test.go index 31192c3e..6d89db4d 100644 --- a/aws-source/adapters/rds-db-cluster-parameter-group_test.go +++ b/aws-source/adapters/rds-db-cluster-parameter-group_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/rds/types" + "github.com/overmindtech/cli/sdpcache" ) func TestDBClusterParameterGroupOutputMapper(t *testing.T) { @@ -86,7 +87,7 @@ func TestDBClusterParameterGroupOutputMapper(t *testing.T) { func TestNewRDSDBClusterParameterGroupAdapter(t *testing.T) { client, account, region := rdsGetAutoConfig(t) - adapter := NewRDSDBClusterParameterGroupAdapter(client, account, region, nil) + adapter := NewRDSDBClusterParameterGroupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/rds-db-cluster_test.go b/aws-source/adapters/rds-db-cluster_test.go index 41460351..c35d57e9 100644 --- a/aws-source/adapters/rds-db-cluster_test.go +++ b/aws-source/adapters/rds-db-cluster_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDBClusterOutputMapper(t *testing.T) { @@ -253,7 +254,7 @@ func TestDBClusterOutputMapper(t *testing.T) { func TestNewRDSDBClusterAdapter(t *testing.T) { client, account, region := rdsGetAutoConfig(t) - adapter := NewRDSDBClusterAdapter(client, account, region, nil) + adapter := NewRDSDBClusterAdapter(client, account, region, 
sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/rds-db-instance_test.go b/aws-source/adapters/rds-db-instance_test.go index 3d7fbfd5..6fe35b5b 100644 --- a/aws-source/adapters/rds-db-instance_test.go +++ b/aws-source/adapters/rds-db-instance_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDBInstanceOutputMapper(t *testing.T) { @@ -311,7 +312,7 @@ func TestDBInstanceOutputMapper(t *testing.T) { func TestNewRDSDBInstanceAdapter(t *testing.T) { client, account, region := rdsGetAutoConfig(t) - adapter := NewRDSDBInstanceAdapter(client, account, region, nil) + adapter := NewRDSDBInstanceAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/rds-db-parameter-group_test.go b/aws-source/adapters/rds-db-parameter-group_test.go index bd654fd9..3c3d10e0 100644 --- a/aws-source/adapters/rds-db-parameter-group_test.go +++ b/aws-source/adapters/rds-db-parameter-group_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/rds/types" + "github.com/overmindtech/cli/sdpcache" ) func TestDBParameterGroupOutputMapper(t *testing.T) { @@ -74,7 +75,7 @@ func TestDBParameterGroupOutputMapper(t *testing.T) { func TestNewRDSDBParameterGroupAdapter(t *testing.T) { client, account, region := rdsGetAutoConfig(t) - adapter := NewRDSDBParameterGroupAdapter(client, account, region, nil) + adapter := NewRDSDBParameterGroupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/rds-db-subnet-group_test.go b/aws-source/adapters/rds-db-subnet-group_test.go index f2156114..4c1ac8e8 100644 --- a/aws-source/adapters/rds-db-subnet-group_test.go +++ b/aws-source/adapters/rds-db-subnet-group_test.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestDBSubnetGroupOutputMapper(t *testing.T) { @@ -85,7 +86,7 @@ func TestDBSubnetGroupOutputMapper(t *testing.T) { func TestNewRDSDBSubnetGroupAdapter(t *testing.T) { client, account, region := rdsGetAutoConfig(t) - adapter := NewRDSDBSubnetGroupAdapter(client, account, region, nil) + adapter := NewRDSDBSubnetGroupAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/route53-health-check_test.go b/aws-source/adapters/route53-health-check_test.go index 124e0181..69024a9a 100644 --- a/aws-source/adapters/route53-health-check_test.go +++ b/aws-source/adapters/route53-health-check_test.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestHealthCheckItemMapper(t *testing.T) { @@ -72,7 +73,7 @@ func TestHealthCheckItemMapper(t *testing.T) { func TestNewRoute53HealthCheckAdapter(t *testing.T) { client, account, region := route53GetAutoConfig(t) - adapter := NewRoute53HealthCheckAdapter(client, account, region, nil) + adapter := NewRoute53HealthCheckAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/route53-hosted-zone.go b/aws-source/adapters/route53-hosted-zone.go index 4520076e..bf6f855e 100644 --- 
a/aws-source/adapters/route53-hosted-zone.go +++ b/aws-source/adapters/route53-hosted-zone.go @@ -82,7 +82,7 @@ func NewRoute53HostedZoneAdapter(client *route53.Client, accountID string, regio ListFunc: hostedZoneListFunc, ItemMapper: hostedZoneItemMapper, AdapterMetadata: hostedZoneAdapterMetadata, - cache: cache, + cache: cache, ListTagsFunc: func(ctx context.Context, hz *types.HostedZone, c *route53.Client) (map[string]string, error) { if hz.Id == nil { return nil, nil diff --git a/aws-source/adapters/route53-hosted-zone_test.go b/aws-source/adapters/route53-hosted-zone_test.go index 0230935b..322c2da5 100644 --- a/aws-source/adapters/route53-hosted-zone_test.go +++ b/aws-source/adapters/route53-hosted-zone_test.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestHostedZoneItemMapper(t *testing.T) { @@ -49,7 +50,7 @@ func TestHostedZoneItemMapper(t *testing.T) { func TestNewRoute53HostedZoneAdapter(t *testing.T) { client, account, region := route53GetAutoConfig(t) - adapter := NewRoute53HostedZoneAdapter(client, account, region, nil) + adapter := NewRoute53HostedZoneAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/route53-resource-record-set.go b/aws-source/adapters/route53-resource-record-set.go index 61eb5a24..d6ba1d8a 100644 --- a/aws-source/adapters/route53-resource-record-set.go +++ b/aws-source/adapters/route53-resource-record-set.go @@ -199,7 +199,7 @@ func NewRoute53ResourceRecordSetAdapter(client *route53.Client, accountID string ItemMapper: resourceRecordSetItemMapper, SearchFunc: resourceRecordSetSearchFunc, AdapterMetadata: resourceRecordSetAdapterMetadata, - cache: cache} + cache: cache} } var resourceRecordSetAdapterMetadata = Metadata.Register(&sdp.AdapterMetadata{ diff --git a/aws-source/adapters/route53-resource-record-set_test.go b/aws-source/adapters/route53-resource-record-set_test.go index aa9b94fc..6719e02e 100644 --- a/aws-source/adapters/route53-resource-record-set_test.go +++ b/aws-source/adapters/route53-resource-record-set_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestResourceRecordSetItemMapper(t *testing.T) { @@ -208,7 +209,7 @@ func TestConstructRecordFQDN(t *testing.T) { func TestNewRoute53ResourceRecordSetAdapter(t *testing.T) { client, account, region := route53GetAutoConfig(t) - zoneSource := NewRoute53HostedZoneAdapter(client, account, region, nil) + zoneSource := NewRoute53HostedZoneAdapter(client, account, region, sdpcache.NewNoOpCache()) zones, err := zoneSource.List(context.Background(), zoneSource.Scopes()[0], true) if err != nil { @@ -219,7 +220,7 @@ func TestNewRoute53ResourceRecordSetAdapter(t *testing.T) { t.Skip("no zones found") } - adapter := NewRoute53ResourceRecordSetAdapter(client, account, region, nil) + adapter := NewRoute53ResourceRecordSetAdapter(client, account, region, sdpcache.NewNoOpCache()) search := zones[0].UniqueAttributeValue() test := E2ETest{ diff --git a/aws-source/adapters/s3.go b/aws-source/adapters/s3.go index fd7338f0..46599ed2 100644 --- a/aws-source/adapters/s3.go +++ b/aws-source/adapters/s3.go @@ -2,6 +2,7 @@ package adapters import ( "context" + "errors" "fmt" "sync" "time" @@ -17,11 +18,6 @@ import ( const CacheDuration = 10 * time.Minute -var ( - noOpCacheS3Once sync.Once - noOpCacheS3 
sdpcache.Cache -) - // NewS3Source Creates a new S3 adapter func NewS3Adapter(config aws.Config, accountID string, cache sdpcache.Cache) *S3Source { return &S3Source{ @@ -89,16 +85,6 @@ type S3Source struct { cache sdpcache.Cache // The cache for this adapter (set during creation, can be nil for tests) } -func (s *S3Source) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheS3Once.Do(func() { - noOpCacheS3 = sdpcache.NewNoOpCache() - }) - return noOpCacheS3 - } - return s.cache -} - func (s *S3Source) Client() *s3.Client { s.clientMutex.Lock() defer s.clientMutex.Unlock() @@ -202,7 +188,7 @@ func (s *S3Source) Get(ctx context.Context, scope string, query string, ignoreCa } } - return getImpl(ctx, s.Cache(), s.Client(), scope, query, ignoreCache) + return getImpl(ctx, s.cache, s.Client(), scope, query, ignoreCache) } func getImpl(ctx context.Context, cache sdpcache.Cache, client S3Client, scope string, query string, ignoreCache bool) (*sdp.Item, error) { @@ -236,7 +222,13 @@ func getImpl(ctx context.Context, cache sdpcache.Cache, client S3Client, scope s if err != nil { err = WrapAWSError(err) - cache.StoreError(ctx, err, CacheDuration, ck) + var queryErr *sdp.QueryError + if errors.As(err, &queryErr) { + // Cache not-found errors and other non-retryable errors + if queryErr.GetErrorType() == sdp.QueryError_NOTFOUND || !CanRetry(queryErr) { + cache.StoreError(ctx, err, CacheDuration, ck) + } + } return nil, err } @@ -607,7 +599,7 @@ func (s *S3Source) List(ctx context.Context, scope string, ignoreCache bool) ([] } } - return listImpl(ctx, s.Cache(), s.Client(), scope, ignoreCache) + return listImpl(ctx, s.cache, s.Client(), scope, ignoreCache) } func listImpl(ctx context.Context, cache sdpcache.Cache, client S3Client, scope string, ignoreCache bool) ([]*sdp.Item, error) { @@ -619,6 +611,10 @@ func listImpl(ctx context.Context, cache sdpcache.Cache, client S3Client, scope cacheHit, ck, cachedItems, qErr, done := cache.Lookup(ctx, "aws-s3-adapter", sdp.QueryMethod_LIST, scope, "s3-bucket", "", ignoreCache) defer done() if qErr != nil { + // Convert a cached NOTFOUND (a previous LIST found no buckets) into an empty result rather than an error + if qErr.GetErrorType() == sdp.QueryError_NOTFOUND { + return []*sdp.Item{}, nil + } return nil, qErr } if cacheHit { @@ -639,14 +635,33 @@ func listImpl(ctx context.Context, cache sdpcache.Cache, client S3Client, scope return nil, err } + hadErrors := false for _, bucket := range buckets.Buckets { item, err := getImpl(ctx, cache, client, scope, *bucket.Name, ignoreCache) if err != nil { + hadErrors = true continue } - items = append(items, item) + if item != nil { + items = append(items, item) + } + } + + // Cache not-found only when no buckets were returned AND no errors occurred + // If we had errors, buckets may exist but we couldn't fetch them + if len(items) == 0 && !hadErrors && len(buckets.Buckets) == 0 { + notFoundErr := &sdp.QueryError{ + ErrorType: sdp.QueryError_NOTFOUND, + ErrorString: "no s3-bucket found in scope " + scope, + Scope: scope, + SourceName: "aws-s3-adapter", + ItemType: "s3-bucket", + ResponderName: "aws-s3-adapter", + } + cache.StoreError(ctx, notFoundErr, CacheDuration, ck) + return items, nil } for _, item := range items { @@ -665,7 +680,7 @@ func (s *S3Source) Search(ctx context.Context, scope string, query string, ignor } } - return searchImpl(ctx, s.Cache(), s.Client(), scope, query, ignoreCache) + return searchImpl(ctx, s.cache, s.Client(), scope, query, ignoreCache) } func searchImpl(ctx context.Context, cache sdpcache.Cache, client S3Client, scope 
string, query string, ignoreCache bool) ([]*sdp.Item, error) { @@ -695,7 +710,10 @@ func searchImpl(ctx context.Context, cache sdpcache.Cache, client S3Client, scop return nil, err } - return []*sdp.Item{item}, nil + if item != nil { + return []*sdp.Item{item}, nil + } + return []*sdp.Item{}, nil } // Weight Returns the priority weighting of items returned by this adapter. diff --git a/aws-source/adapters/s3_test.go b/aws-source/adapters/s3_test.go index d687779a..b3e23b51 100644 --- a/aws-source/adapters/s3_test.go +++ b/aws-source/adapters/s3_test.go @@ -3,18 +3,18 @@ package adapters import ( "context" "errors" + "strings" + "testing" + "time" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/overmindtech/cli/sdp-go" "github.com/overmindtech/cli/sdpcache" - "strings" - "testing" - "time" ) func TestS3SearchImpl(t *testing.T) { - cache := sdpcache.NewCache(t.Context()) - + cache := sdpcache.NewNoOpCache() t.Run("with S3 bucket ARN format (empty account ID and region)", func(t *testing.T) { // This test verifies that S3 bucket ARNs with empty account ID and region work correctly // Format: arn:aws:s3:::bucket-name @@ -61,7 +61,7 @@ func TestS3SearchImpl(t *testing.T) { } func TestS3ListImpl(t *testing.T) { - cache := sdpcache.NewCache(t.Context()) + cache := sdpcache.NewNoOpCache() items, err := listImpl(context.Background(), cache, TestS3Client{}, "foo", false) if err != nil { @@ -73,7 +73,7 @@ func TestS3ListImpl(t *testing.T) { } func TestS3GetImpl(t *testing.T) { - cache := sdpcache.NewCache(t.Context()) + cache := sdpcache.NewNoOpCache() item, err := getImpl(context.Background(), cache, TestS3Client{}, "foo", "bar", false) if err != nil { @@ -129,7 +129,7 @@ func TestS3GetImpl(t *testing.T) { } func TestS3SourceCaching(t *testing.T) { - cache := sdpcache.NewCache(t.Context()) + cache := sdpcache.NewMemoryCache() first, err := getImpl(context.Background(), cache, TestS3Client{}, "foo", "bar", false) if err != nil { t.Fatal(err) @@ -664,7 +664,7 @@ func (t TestS3FailClient) PutObject(ctx context.Context, params *s3.PutObjectInp func TestNewS3Adapter(t *testing.T) { config, account, _ := GetAutoConfig(t) - adapter := NewS3Adapter(config, account, nil) + adapter := NewS3Adapter(config, account, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, @@ -683,7 +683,7 @@ func TestS3SearchWithARNFormat(t *testing.T) { // CURRENT BEHAVIOR: Get works, Search fails with NOSCOPE error - THIS IS THE BUG config, account, _ := GetAutoConfig(t) - adapter := NewS3Adapter(config, account, nil) + adapter := NewS3Adapter(config, account, sdpcache.NewNoOpCache()) scope := adapter.Scopes()[0] bucketName := "harness-sample-three-qa-us-west-2-20251022151048279100000001" diff --git a/aws-source/adapters/sns-data-protection-policy_test.go b/aws-source/adapters/sns-data-protection-policy_test.go index afbed01b..a1c75dca 100644 --- a/aws-source/adapters/sns-data-protection-policy_test.go +++ b/aws-source/adapters/sns-data-protection-policy_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/service/sns" + "github.com/overmindtech/cli/sdpcache" ) type mockDataProtectionPolicyClient struct{} @@ -36,7 +37,7 @@ func TestNewSNSDataProtectionPolicyAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := sns.NewFromConfig(config) - adapter := NewSNSDataProtectionPolicyAdapter(client, account, region, nil) + adapter := NewSNSDataProtectionPolicyAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ 
Adapter: adapter, diff --git a/aws-source/adapters/sns-endpoint_test.go b/aws-source/adapters/sns-endpoint_test.go index b7378b74..e8ddd5b2 100644 --- a/aws-source/adapters/sns-endpoint_test.go +++ b/aws-source/adapters/sns-endpoint_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/overmindtech/cli/sdpcache" ) type mockEndpointClient struct{} @@ -58,7 +59,7 @@ func TestNewSNSEndpointAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := sns.NewFromConfig(config) - adapter := NewSNSEndpointAdapter(client, account, region, nil) + adapter := NewSNSEndpointAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/sns-platform-application_test.go b/aws-source/adapters/sns-platform-application_test.go index ca859c9f..3bc4d2a3 100644 --- a/aws-source/adapters/sns-platform-application_test.go +++ b/aws-source/adapters/sns-platform-application_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/overmindtech/cli/sdpcache" ) type mockPlatformApplicationClient struct{} @@ -70,7 +71,7 @@ func TestNewSNSPlatformApplicationAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := sns.NewFromConfig(config) - adapter := NewSNSPlatformApplicationAdapter(client, account, region, nil) + adapter := NewSNSPlatformApplicationAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/sns-subscription_test.go b/aws-source/adapters/sns-subscription_test.go index 8c607e7f..f3b46f59 100644 --- a/aws-source/adapters/sns-subscription_test.go +++ b/aws-source/adapters/sns-subscription_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/overmindtech/cli/sdpcache" ) type snsTestClient struct{} @@ -67,7 +68,7 @@ func TestNewSNSSubscriptionAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := sns.NewFromConfig(config) - adapter := NewSNSSubscriptionAdapter(client, account, region, nil) + adapter := NewSNSSubscriptionAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/sns-topic_test.go b/aws-source/adapters/sns-topic_test.go index 9615a976..484fd227 100644 --- a/aws-source/adapters/sns-topic_test.go +++ b/aws-source/adapters/sns-topic_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sns" "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/overmindtech/cli/sdpcache" ) type testTopicClient struct{} @@ -64,7 +65,7 @@ func TestNewSNSTopicAdapter(t *testing.T) { config, account, region := GetAutoConfig(t) client := sns.NewFromConfig(config) - adapter := NewSNSTopicAdapter(client, account, region, nil) + adapter := NewSNSTopicAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/sqs-queue_test.go b/aws-source/adapters/sqs-queue_test.go index c89f315b..54672448 100644 --- a/aws-source/adapters/sqs-queue_test.go +++ b/aws-source/adapters/sqs-queue_test.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) type testClient struct{} @@ -164,7 +165,7 @@ func TestNewQueueAdapter(t *testing.T) { config, account, 
region := GetAutoConfig(t) client := sqs.NewFromConfig(config) - adapter := NewSQSQueueAdapter(client, account, region, nil) + adapter := NewSQSQueueAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/adapters/ssm-parameter.go b/aws-source/adapters/ssm-parameter.go index d26a696b..ea8b5a32 100644 --- a/aws-source/adapters/ssm-parameter.go +++ b/aws-source/adapters/ssm-parameter.go @@ -226,7 +226,7 @@ func NewSSMParameterAdapter(client ssmClient, accountID string, region string, c Region: region, ItemType: "ssm-parameter", AdapterMetadata: ssmParameterAdapterMetadata, - cache: cache, + cache: cache, InputMapperGet: func(scope, query string) (*ssm.DescribeParametersInput, error) { return &ssm.DescribeParametersInput{ ParameterFilters: []types.ParameterStringFilter{ diff --git a/aws-source/adapters/ssm-parameter_test.go b/aws-source/adapters/ssm-parameter_test.go index 1e592872..2c01d3b7 100644 --- a/aws-source/adapters/ssm-parameter_test.go +++ b/aws-source/adapters/ssm-parameter_test.go @@ -9,6 +9,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/overmindtech/cli/discovery" + "github.com/overmindtech/cli/sdpcache" ) type mockSSMClient struct{} @@ -68,7 +69,7 @@ func (m *mockSSMClient) GetParameter(ctx context.Context, input *ssm.GetParamete } func TestSSMParameterAdapter(t *testing.T) { - adapter := NewSSMParameterAdapter(&mockSSMClient{}, "123456789", "us-east-1", nil) + adapter := NewSSMParameterAdapter(&mockSSMClient{}, "123456789", "us-east-1", sdpcache.NewNoOpCache()) t.Run("Get", func(t *testing.T) { item, err := adapter.Get(context.Background(), "123456789.us-east-1", "test", false) @@ -122,7 +123,7 @@ func TestSSMParameterAdapterE2E(t *testing.T) { config, account, region := GetAutoConfig(t) client := ssm.NewFromConfig(config) - adapter := NewSSMParameterAdapter(client, account, region, nil) + adapter := NewSSMParameterAdapter(client, account, region, sdpcache.NewNoOpCache()) test := E2ETest{ Adapter: adapter, diff --git a/aws-source/cmd/root.go b/aws-source/cmd/root.go index 35252367..0a9a4f11 100644 --- a/aws-source/cmd/root.go +++ b/aws-source/cmd/root.go @@ -23,99 +23,74 @@ var cfgFile string // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ - Use: "aws-source", - Short: "Remote primary source for AWS", + Use: "aws-source", + Short: "Remote primary source for AWS", + SilenceUsage: true, Long: `This source looks for AWS resources in your account. 
`, - Run: func(cmd *cobra.Command, args []string) { - ctx := context.Background() + RunE: func(cmd *cobra.Command, args []string) error { + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() defer tracing.LogRecoverToReturn(ctx, "aws-source.root") healthCheckPort := viper.GetInt("health-check-port") - awsAuthConfig := proc.AwsAuthConfig{ - Strategy: viper.GetString("aws-access-strategy"), - AccessKeyID: viper.GetString("aws-access-key-id"), - SecretAccessKey: viper.GetString("aws-secret-access-key"), - ExternalID: viper.GetString("aws-external-id"), - TargetRoleARN: viper.GetString("aws-target-role-arn"), - Profile: viper.GetString("aws-profile"), - AutoConfig: viper.GetBool("auto-config"), - } - - err := viper.UnmarshalKey("aws-regions", &awsAuthConfig.Regions) - if err != nil { - log.WithError(err).Fatal("Could not parse aws-regions") - } - engineConfig, err := discovery.EngineConfigFromViper("aws", tracing.Version()) if err != nil { - log.WithError(err).Fatal("Could not create engine config") + log.WithError(err).Error("Could not create engine config") + return fmt.Errorf("could not create engine config: %w", err) } - log.WithFields(log.Fields{ - "aws-regions": awsAuthConfig.Regions, - "aws-access-strategy": awsAuthConfig.Strategy, - "aws-external-id": awsAuthConfig.ExternalID, - "aws-target-role-arn": awsAuthConfig.TargetRoleARN, - "aws-profile": awsAuthConfig.Profile, - "auto-config": awsAuthConfig.AutoConfig, - "health-check-port": healthCheckPort, - }).Info("Got config") - - err = engineConfig.CreateClients() + // Create a basic engine first so we can serve health probes and heartbeats even if init fails + e, err := discovery.NewEngine(engineConfig) if err != nil { sentry.CaptureException(err) - log.WithError(err).Fatal("could not auth create clients") - } - - rateLimitContext, rateLimitCancel := context.WithCancel(context.Background()) - defer rateLimitCancel() - - configs, err := proc.CreateAWSConfigs(awsAuthConfig) - if err != nil { - log.WithError(err).Fatal("Could not create AWS configs") - } - - // Initialize the engine - e, err := proc.InitializeAwsSourceEngine( - rateLimitContext, - engineConfig, - 999_999, // Very high max retries as it'll time out after 15min anyway - configs..., - ) - if err != nil { - log.WithError(err).Fatal("Could not initialize AWS source") + log.WithError(err).Error("Could not create engine") + return fmt.Errorf("could not create engine: %w", err) } + // Serve health probes before initialization so they're available even on failure e.ServeHealthProbes(healthCheckPort) + // Start the engine (NATS connection) before adapter init so heartbeats work err = e.Start(ctx) if err != nil { - log.WithFields(log.Fields{ - "error": err, - }).Fatal("Could not start engine") + sentry.CaptureException(err) + log.WithError(err).Error("Could not start engine") + return fmt.Errorf("could not start engine: %w", err) } - sigs := make(chan os.Signal, 1) - - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + // Config validation (permanent errors — no retry, just idle with error) + configs, cfgErr := proc.ConfigFromViper() + if cfgErr != nil { + log.WithError(cfgErr).Error("AWS source config error - pod will stay running with error status") + e.SetInitError(cfgErr) + sentry.CaptureException(cfgErr) + } else { + log.WithFields(log.Fields{ + "aws-regions": len(configs), + "health-check-port": healthCheckPort, + }).Info("Got config") + // Adapter init (retryable errors — backoff capped at 5 min) + 
e.InitialiseAdapters(ctx, func(ctx context.Context) error { + return proc.InitializeAwsSourceAdapters(ctx, e, configs...) + }) + } - <-sigs + <-ctx.Done() log.Info("Stopping engine") err = e.Stop() - if err != nil { log.WithFields(log.Fields{ "error": err, }).Error("Could not stop engine") - os.Exit(1) + return fmt.Errorf("could not stop engine: %w", err) } log.Info("Stopped") - os.Exit(0) + return nil }, } @@ -165,7 +140,7 @@ func init() { cobra.CheckErr(viper.BindPFlags(rootCmd.PersistentFlags())) // Run this before we do anything to set up the loglevel - rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { if lvl, err := log.ParseLevel(logLevel); err == nil { log.SetLevel(lvl) } else { @@ -178,23 +153,29 @@ func init() { log.AddHook(TerminationLogHook{}) // Bind flags that haven't been set to the values from viper if we have them + var bindErr error cmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { // Bind the flag to viper only if it has a non-empty default if f.DefValue != "" || f.Changed { - err := viper.BindPFlag(f.Name, f) - if err != nil { - log.WithError(err).Fatal("could not bind flag to viper") + if err := viper.BindPFlag(f.Name, f); err != nil { + bindErr = err } } }) + if bindErr != nil { + log.WithError(bindErr).Error("could not bind flag to viper") + return fmt.Errorf("could not bind flag to viper: %w", bindErr) + } if viper.GetBool("json-log") { logging.ConfigureLogrusJSON(log.StandardLogger()) } if err := tracing.InitTracerWithUpstreams("aws-source", viper.GetString("honeycomb-api-key"), viper.GetString("sentry-dsn")); err != nil { - log.Fatal(err) + log.WithError(err).Error("could not init tracer") + return fmt.Errorf("could not init tracer: %w", err) } + return nil } // shut down tracing at the end of the process rootCmd.PersistentPostRun = func(cmd *cobra.Command, args []string) { @@ -227,8 +208,7 @@ func (t TerminationLogHook) Levels() []log.Level { func (t TerminationLogHook) Fire(e *log.Entry) error { // shutdown tracing first to ensure all spans are flushed tracing.ShutdownTracer(context.Background()) - tLog, err := os.OpenFile("/dev/termination-log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - + tLog, err := os.OpenFile("/dev/termination-log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return err } diff --git a/aws-source/proc/proc.go b/aws-source/proc/proc.go index c8ab3c7a..c049e7c4 100644 --- a/aws-source/proc/proc.go +++ b/aws-source/proc/proc.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "strings" + "sync" "sync/atomic" "time" @@ -31,7 +32,7 @@ import ( awssns "github.com/aws/aws-sdk-go-v2/service/sns" awssqs "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/ssm" - "github.com/cenkalti/backoff/v5" + "github.com/aws/smithy-go" "github.com/sourcegraph/conc/pool" "github.com/aws/aws-sdk-go-v2/aws" @@ -43,6 +44,7 @@ import ( "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdpcache" log "github.com/sirupsen/logrus" + "github.com/spf13/viper" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) @@ -62,16 +64,56 @@ type AwsAuthConfig struct { Regions []string } +// ConfigFromViper reads AWS configuration from viper, parses regions, and creates +// AWS configs for each region. Consolidates config loading and validation. 
+func ConfigFromViper() ([]aws.Config, error) { + authConfig := AwsAuthConfig{ + Strategy: viper.GetString("aws-access-strategy"), + AccessKeyID: viper.GetString("aws-access-key-id"), + SecretAccessKey: viper.GetString("aws-secret-access-key"), + ExternalID: viper.GetString("aws-external-id"), + TargetRoleARN: viper.GetString("aws-target-role-arn"), + Profile: viper.GetString("aws-profile"), + AutoConfig: viper.GetBool("auto-config"), + } + if err := viper.UnmarshalKey("aws-regions", &authConfig.Regions); err != nil { + return nil, fmt.Errorf("could not parse aws-regions: %w", err) + } + return CreateAWSConfigs(authConfig) +} + +// isOptInRegionError checks if an error indicates an opt-in region that is not enabled. +// This typically occurs when trying to authenticate with IRSA in a region that hasn't +// been enabled in the AWS account. These errors should not cause source initialization +// to fail - the region should simply be skipped. +func isOptInRegionError(err error) bool { + if err == nil { + return false + } + + // Check for the InvalidIdentityToken error code from STS AssumeRoleWithWebIdentity + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + if apiErr.ErrorCode() == "InvalidIdentityToken" { + // Additional validation: check if it's specifically about OIDC provider + errMsg := err.Error() + if strings.Contains(errMsg, "No OpenIDConnect provider found") { + return true + } + } + } + + return false +} + // wrapRegionError wraps misleading AWS errors with more helpful context func wrapRegionError(err error, region string) error { if err == nil { return nil } - errMsg := err.Error() - - // Check for OIDC-related errors which often indicate disabled opt-in regions - if strings.Contains(errMsg, "No OpenIDConnect provider found") { + // Check for opt-in region errors and provide helpful context + if isOptInRegionError(err) { return fmt.Errorf("%w. This error often occurs when region '%s' is not enabled in the target AWS account", err, region) } @@ -210,16 +252,12 @@ func CreateAWSConfigs(awsAuthConfig AwsAuthConfig) ([]aws.Config, error) { return configs, nil } -// InitializeAwsSourceEngine initializes an Engine with AWS sources, returns the -// engine, and an error if any. The context provided will be used for the rate -// limit buckets and should not be cancelled until the source is shut down. AWS -// configs should be provided for each region that is enabled -func InitializeAwsSourceEngine(ctx context.Context, ec *discovery.EngineConfig, maxRetries int, configs ...aws.Config) (*discovery.Engine, error) { - e, err := discovery.NewEngine(ec) - if err != nil { - return nil, fmt.Errorf("error initializing Engine: %w", err) - } - +// InitializeAwsSourceAdapters adds AWS adapters to an existing engine. This is a single-attempt +// function; retry logic is handled by the caller via Engine.InitialiseAdapters. +// +// The context provided will be used for the rate limit buckets and should not be cancelled until +// the source is shut down. AWS configs should be provided for each region that is enabled. 
+func InitializeAwsSourceAdapters(ctx context.Context, e *discovery.Engine, configs ...aws.Config) error { // Create a shared cache for all adapters in this source sharedCache := sdpcache.NewCache(ctx) @@ -251,365 +289,361 @@ func InitializeAwsSourceEngine(ctx context.Context, ec *discovery.EngineConfig, return nil }) if len(configs) == 0 { - return nil, errors.New("No configs specified") + return errors.New("no configs specified") } var globalDone atomic.Bool - b := backoff.NewExponentialBackOff() - b.MaxInterval = 30 * time.Second - tick := backoff.NewTicker(b) - - try := 0 - - for { - try++ - if try > maxRetries { - return nil, fmt.Errorf("maximum retries (%d) exceeded", maxRetries) - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case _, ok := <-tick.C: - if !ok { - // If the backoff stops, then we should stop trying to - // initialize and just return the error - return nil, err - } - // Clear any adapters from previous retry attempts to avoid - // duplicate registration errors - e.ClearAdapters() - globalDone.Store(false) - - p := pool.New().WithContext(ctx) - - for _, cfg := range configs { - p.Go(func(ctx context.Context) error { - configCtx, configCancel := context.WithTimeout(ctx, 10*time.Second) - defer configCancel() - - log.WithFields(log.Fields{ - "region": cfg.Region, - }).Info("Initializing AWS source") + // Track regions that are skipped due to not being enabled (opt-in regions) + type skippedRegion struct { + region string + err error + } + var skippedRegions []skippedRegion + var skippedRegionsMu sync.Mutex - // Work out what account we're using. This will be used in item scopes - stsClient := sts.NewFromConfig(cfg) + p := pool.New().WithContext(ctx) - callerID, err := stsClient.GetCallerIdentity(configCtx, &sts.GetCallerIdentityInput{}) - if err != nil { - lf := log.Fields{ - "region": cfg.Region, - } + for _, cfg := range configs { + p.Go(func(ctx context.Context) error { + configCtx, configCancel := context.WithTimeout(ctx, 10*time.Second) + defer configCancel() - // Wrap misleading OIDC errors with helpful region enablement context - wrappedErr := wrapRegionError(err, cfg.Region) + log.WithFields(log.Fields{ + "region": cfg.Region, + }).Info("Initializing AWS source") - log.WithError(wrappedErr).WithFields(lf).Error("Error retrieving account information") - return fmt.Errorf("error getting caller identity for region %v: %w", cfg.Region, wrappedErr) - } + // Work out what account we're using. 
This will be used in item scopes + stsClient := sts.NewFromConfig(cfg) - // Create shared clients for each API - autoscalingClient := awsautoscaling.NewFromConfig(cfg, func(o *awsautoscaling.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - cloudfrontClient := awscloudfront.NewFromConfig(cfg, func(o *awscloudfront.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - cloudwatchClient := awscloudwatch.NewFromConfig(cfg, func(o *awscloudwatch.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - directconnectClient := awsdirectconnect.NewFromConfig(cfg, func(o *awsdirectconnect.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - dynamodbClient := awsdynamodb.NewFromConfig(cfg, func(o *awsdynamodb.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - ec2Client := awsec2.NewFromConfig(cfg, func(o *awsec2.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - ecsClient := awsecs.NewFromConfig(cfg, func(o *awsecs.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - efsClient := awsefs.NewFromConfig(cfg, func(o *awsefs.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - eksClient := awseks.NewFromConfig(cfg, func(o *awseks.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - elbClient := awselasticloadbalancing.NewFromConfig(cfg, func(o *awselasticloadbalancing.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - elbv2Client := awselasticloadbalancingv2.NewFromConfig(cfg, func(o *awselasticloadbalancingv2.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - lambdaClient := awslambda.NewFromConfig(cfg, func(o *awslambda.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - networkfirewallClient := awsnetworkfirewall.NewFromConfig(cfg, func(o *awsnetworkfirewall.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - rdsClient := awsrds.NewFromConfig(cfg, func(o *awsrds.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - snsClient := awssns.NewFromConfig(cfg, func(o *awssns.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - sqsClient := awssqs.NewFromConfig(cfg, func(o *awssqs.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - route53Client := awsroute53.NewFromConfig(cfg, func(o *awsroute53.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - networkmanagerClient := awsnetworkmanager.NewFromConfig(cfg, func(o *awsnetworkmanager.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - iamClient := awsiam.NewFromConfig(cfg, func(o *awsiam.Options) { - o.RetryMode = aws.RetryModeAdaptive - // Increase this from the default of 3 since IAM as such low rate limits - o.RetryMaxAttempts = 5 - }) - kmsClient := awskms.NewFromConfig(cfg, func(o *awskms.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - apigatewayClient := awsapigateway.NewFromConfig(cfg, func(o *awsapigateway.Options) { - o.RetryMode = aws.RetryModeAdaptive - }) - ssmClient := ssm.NewFromConfig(cfg, func(o *ssm.Options) { - o.RetryMode = aws.RetryModeAdaptive + callerID, err := stsClient.GetCallerIdentity(configCtx, &sts.GetCallerIdentityInput{}) + if err != nil { + lf := log.Fields{ + "region": cfg.Region, + } + + // Check if this is an opt-in region error + if isOptInRegionError(err) { + // This region is not enabled in the account - skip it but don't fail + wrappedErr := wrapRegionError(err, cfg.Region) + skippedRegionsMu.Lock() + skippedRegions = append(skippedRegions, skippedRegion{ + region: cfg.Region, + err: wrappedErr, }) + skippedRegionsMu.Unlock() + log.WithError(wrappedErr).WithFields(lf).Warn("Skipping region - not enabled in account") + return nil // Don't fail the pool 
for opt-in regions + } + + // Wrap misleading OIDC errors with helpful region enablement context + wrappedErr := wrapRegionError(err, cfg.Region) - configuredAdapters := []discovery.Adapter{ - // EC2 - adapters.NewEC2AddressAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2CapacityReservationFleetAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2CapacityReservationAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2EgressOnlyInternetGatewayAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2IamInstanceProfileAssociationAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2ImageAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2InstanceEventWindowAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2InstanceAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2InstanceStatusAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2InternetGatewayAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2KeyPairAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2LaunchTemplateAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2LaunchTemplateVersionAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2NatGatewayAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2NetworkAclAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2NetworkInterfacePermissionAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2NetworkInterfaceAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2PlacementGroupAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2ReservedInstanceAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2RouteTableAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2SecurityGroupRuleAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2SecurityGroupAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2SnapshotAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2SubnetAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2VolumeAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2VolumeStatusAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2VpcEndpointAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2VpcPeeringConnectionAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEC2VpcAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), - - // EFS (I'm assuming it shares its rate limit with EC2)) - adapters.NewEFSAccessPointAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEFSBackupPolicyAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEFSFileSystemAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEFSMountTargetAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEFSReplicationConfigurationAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), - - // EKS - adapters.NewEKSAddonAdapter(eksClient, 
*callerID.Account, cfg.Region, sharedCache), - adapters.NewEKSClusterAdapter(eksClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEKSFargateProfileAdapter(eksClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewEKSNodegroupAdapter(eksClient, *callerID.Account, cfg.Region, sharedCache), - - // Route 53 - adapters.NewRoute53HealthCheckAdapter(route53Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewRoute53HostedZoneAdapter(route53Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewRoute53ResourceRecordSetAdapter(route53Client, *callerID.Account, cfg.Region, sharedCache), - - // Cloudwatch - adapters.NewCloudwatchAlarmAdapter(cloudwatchClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewCloudwatchInstanceMetricAdapter(cloudwatchClient, *callerID.Account, cfg.Region, sharedCache), - - // Lambda - adapters.NewLambdaFunctionAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewLambdaLayerAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewLambdaLayerVersionAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewLambdaEventSourceMappingAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), - - // ECS - adapters.NewECSCapacityProviderAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewECSClusterAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewECSContainerInstanceAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewECSServiceAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewECSTaskDefinitionAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewECSTaskAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), - - // DynamoDB - adapters.NewDynamoDBBackupAdapter(dynamodbClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDynamoDBTableAdapter(dynamodbClient, *callerID.Account, cfg.Region, sharedCache), - - // RDS - adapters.NewRDSDBClusterParameterGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewRDSDBClusterAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewRDSDBInstanceAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewRDSDBParameterGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewRDSDBSubnetGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewRDSOptionGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), - - // AutoScaling - adapters.NewAutoScalingGroupAdapter(autoscalingClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAutoScalingPolicyAdapter(autoscalingClient, *callerID.Account, cfg.Region, sharedCache), - - // ELB - adapters.NewELBInstanceHealthAdapter(elbClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewELBLoadBalancerAdapter(elbClient, *callerID.Account, cfg.Region, sharedCache), - - // ELBv2 - adapters.NewELBv2ListenerAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewELBv2LoadBalancerAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewELBv2RuleAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewELBv2TargetGroupAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), - adapters.NewELBv2TargetHealthAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), - - // Network Firewall - 
adapters.NewNetworkFirewallFirewallAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkFirewallFirewallPolicyAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkFirewallRuleGroupAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkFirewallTLSInspectionConfigurationAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), - - // Direct Connect - adapters.NewDirectConnectGatewayAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectGatewayAssociationAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectGatewayAssociationProposalAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectConnectionAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectGatewayAttachmentAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectVirtualInterfaceAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectVirtualGatewayAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectCustomerMetadataAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectLagAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectLocationAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectHostedConnectionAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectInterconnectAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewDirectConnectRouterConfigurationAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), - - // Network Manager - adapters.NewNetworkManagerConnectAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerConnectPeerAssociationAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerConnectPeerAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerCoreNetworkPolicyAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerCoreNetworkAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerNetworkResourceRelationshipsAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerSiteToSiteVpnAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerTransitGatewayConnectPeerAssociationAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerTransitGatewayPeeringAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerTransitGatewayRegistrationAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerTransitGatewayRouteTableAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewNetworkManagerVPCAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), - - // SQS - 
adapters.NewSQSQueueAdapter(sqsClient, *callerID.Account, cfg.Region, sharedCache), - - // SNS - adapters.NewSNSSubscriptionAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewSNSTopicAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewSNSPlatformApplicationAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewSNSEndpointAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewSNSDataProtectionPolicyAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), - - // KMS - adapters.NewKMSKeyAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewKMSCustomKeyStoreAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewKMSAliasAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewKMSGrantAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewKMSKeyPolicyAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), - - // ApiGateway - adapters.NewAPIGatewayRestApiAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayResourceAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayDomainNameAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayMethodAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayMethodResponseAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayIntegrationAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayApiKeyAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayAuthorizerAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayDeploymentAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayStageAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - adapters.NewAPIGatewayModelAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), - - // SSM - adapters.NewSSMParameterAdapter(ssmClient, *callerID.Account, cfg.Region, sharedCache), - } - - err = e.AddAdapters(configuredAdapters...) - if err != nil { - return err - } - - // Add "global" sources (those that aren't tied to a region, like - // cloudfront). but only do this once for the first region. 
For - // these APIs it doesn't matter which region we call them from, we - // get global results - if globalDone.CompareAndSwap(false, true) { - globalAdapters := []discovery.Adapter{ - // Cloudfront - adapters.NewCloudfrontCachePolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontContinuousDeploymentPolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontDistributionAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontCloudfrontFunctionAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontKeyGroupAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontOriginAccessControlAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontOriginRequestPolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontResponseHeadersPolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontRealtimeLogConfigsAdapter(cloudfrontClient, *callerID.Account, sharedCache), - adapters.NewCloudfrontStreamingDistributionAdapter(cloudfrontClient, *callerID.Account, sharedCache), - - // S3 - adapters.NewS3Adapter(cfg, *callerID.Account, sharedCache), - - // Networkmanager - adapters.NewNetworkManagerGlobalNetworkAdapter(networkmanagerClient, *callerID.Account, sharedCache), - adapters.NewNetworkManagerSiteAdapter(networkmanagerClient, *callerID.Account, sharedCache), - adapters.NewNetworkManagerLinkAdapter(networkmanagerClient, *callerID.Account, sharedCache), - adapters.NewNetworkManagerDeviceAdapter(networkmanagerClient, *callerID.Account, sharedCache), - adapters.NewNetworkManagerLinkAssociationAdapter(networkmanagerClient, *callerID.Account, sharedCache), - adapters.NewNetworkManagerConnectionAdapter(networkmanagerClient, *callerID.Account, sharedCache), - - // IAM - adapters.NewIAMPolicyAdapter(iamClient, *callerID.Account, sharedCache), - adapters.NewIAMGroupAdapter(iamClient, *callerID.Account, sharedCache), - adapters.NewIAMInstanceProfileAdapter(iamClient, *callerID.Account, sharedCache), - adapters.NewIAMRoleAdapter(iamClient, *callerID.Account, sharedCache), - adapters.NewIAMUserAdapter(iamClient, *callerID.Account, sharedCache), - } - - err = e.AddAdapters(globalAdapters...) 
- if err != nil { - return err - } - } - return nil - }) + log.WithError(wrappedErr).WithFields(lf).Error("Error retrieving account information") + return fmt.Errorf("error getting caller identity for region %v: %w", cfg.Region, wrappedErr) } - err = p.Wait() - brokenHeart := e.SendHeartbeat(ctx, nil) // Send heartbeat with any errors - if brokenHeart != nil { - log.WithError(brokenHeart).Error("Error sending heartbeat") + // Create shared clients for each API + autoscalingClient := awsautoscaling.NewFromConfig(cfg, func(o *awsautoscaling.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + cloudfrontClient := awscloudfront.NewFromConfig(cfg, func(o *awscloudfront.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + cloudwatchClient := awscloudwatch.NewFromConfig(cfg, func(o *awscloudwatch.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + directconnectClient := awsdirectconnect.NewFromConfig(cfg, func(o *awsdirectconnect.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + dynamodbClient := awsdynamodb.NewFromConfig(cfg, func(o *awsdynamodb.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + ec2Client := awsec2.NewFromConfig(cfg, func(o *awsec2.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + ecsClient := awsecs.NewFromConfig(cfg, func(o *awsecs.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + efsClient := awsefs.NewFromConfig(cfg, func(o *awsefs.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + eksClient := awseks.NewFromConfig(cfg, func(o *awseks.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + elbClient := awselasticloadbalancing.NewFromConfig(cfg, func(o *awselasticloadbalancing.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + elbv2Client := awselasticloadbalancingv2.NewFromConfig(cfg, func(o *awselasticloadbalancingv2.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + lambdaClient := awslambda.NewFromConfig(cfg, func(o *awslambda.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + networkfirewallClient := awsnetworkfirewall.NewFromConfig(cfg, func(o *awsnetworkfirewall.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + rdsClient := awsrds.NewFromConfig(cfg, func(o *awsrds.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + snsClient := awssns.NewFromConfig(cfg, func(o *awssns.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + sqsClient := awssqs.NewFromConfig(cfg, func(o *awssqs.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + route53Client := awsroute53.NewFromConfig(cfg, func(o *awsroute53.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + networkmanagerClient := awsnetworkmanager.NewFromConfig(cfg, func(o *awsnetworkmanager.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + iamClient := awsiam.NewFromConfig(cfg, func(o *awsiam.Options) { + o.RetryMode = aws.RetryModeAdaptive + // Increase this from the default of 3 since IAM as such low rate limits + o.RetryMaxAttempts = 5 + }) + kmsClient := awskms.NewFromConfig(cfg, func(o *awskms.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + apigatewayClient := awsapigateway.NewFromConfig(cfg, func(o *awsapigateway.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + ssmClient := ssm.NewFromConfig(cfg, func(o *ssm.Options) { + o.RetryMode = aws.RetryModeAdaptive + }) + + configuredAdapters := []discovery.Adapter{ + // EC2 + adapters.NewEC2AddressAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2CapacityReservationFleetAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + 
adapters.NewEC2CapacityReservationAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2EgressOnlyInternetGatewayAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2IamInstanceProfileAssociationAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2ImageAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2InstanceEventWindowAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2InstanceAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2InstanceStatusAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2InternetGatewayAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2KeyPairAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2LaunchTemplateAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2LaunchTemplateVersionAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2NatGatewayAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2NetworkAclAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2NetworkInterfacePermissionAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2NetworkInterfaceAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2PlacementGroupAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2ReservedInstanceAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2RouteTableAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2SecurityGroupRuleAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2SecurityGroupAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2SnapshotAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2SubnetAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2VolumeAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2VolumeStatusAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2VpcEndpointAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2VpcPeeringConnectionAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEC2VpcAdapter(ec2Client, *callerID.Account, cfg.Region, sharedCache), + + // EFS (I'm assuming it shares its rate limit with EC2)) + adapters.NewEFSAccessPointAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEFSBackupPolicyAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEFSFileSystemAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEFSMountTargetAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEFSReplicationConfigurationAdapter(efsClient, *callerID.Account, cfg.Region, sharedCache), + + // EKS + adapters.NewEKSAddonAdapter(eksClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEKSClusterAdapter(eksClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEKSFargateProfileAdapter(eksClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewEKSNodegroupAdapter(eksClient, *callerID.Account, cfg.Region, sharedCache), + + // Route 53 + adapters.NewRoute53HealthCheckAdapter(route53Client, *callerID.Account, 
cfg.Region, sharedCache), + adapters.NewRoute53HostedZoneAdapter(route53Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewRoute53ResourceRecordSetAdapter(route53Client, *callerID.Account, cfg.Region, sharedCache), + + // Cloudwatch + adapters.NewCloudwatchAlarmAdapter(cloudwatchClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewCloudwatchInstanceMetricAdapter(cloudwatchClient, *callerID.Account, cfg.Region, sharedCache), + + // Lambda + adapters.NewLambdaFunctionAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewLambdaLayerAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewLambdaLayerVersionAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewLambdaEventSourceMappingAdapter(lambdaClient, *callerID.Account, cfg.Region, sharedCache), + + // ECS + adapters.NewECSCapacityProviderAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewECSClusterAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewECSContainerInstanceAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewECSServiceAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewECSTaskDefinitionAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewECSTaskAdapter(ecsClient, *callerID.Account, cfg.Region, sharedCache), + + // DynamoDB + adapters.NewDynamoDBBackupAdapter(dynamodbClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDynamoDBTableAdapter(dynamodbClient, *callerID.Account, cfg.Region, sharedCache), + + // RDS + adapters.NewRDSDBClusterParameterGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewRDSDBClusterAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewRDSDBInstanceAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewRDSDBParameterGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewRDSDBSubnetGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewRDSOptionGroupAdapter(rdsClient, *callerID.Account, cfg.Region, sharedCache), + + // AutoScaling + adapters.NewAutoScalingGroupAdapter(autoscalingClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAutoScalingPolicyAdapter(autoscalingClient, *callerID.Account, cfg.Region, sharedCache), + + // ELB + adapters.NewELBInstanceHealthAdapter(elbClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewELBLoadBalancerAdapter(elbClient, *callerID.Account, cfg.Region, sharedCache), + + // ELBv2 + adapters.NewELBv2ListenerAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewELBv2LoadBalancerAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewELBv2RuleAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewELBv2TargetGroupAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), + adapters.NewELBv2TargetHealthAdapter(elbv2Client, *callerID.Account, cfg.Region, sharedCache), + + // Network Firewall + adapters.NewNetworkFirewallFirewallAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkFirewallFirewallPolicyAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkFirewallRuleGroupAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), + 
adapters.NewNetworkFirewallTLSInspectionConfigurationAdapter(networkfirewallClient, *callerID.Account, cfg.Region, sharedCache), + + // Direct Connect + adapters.NewDirectConnectGatewayAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectGatewayAssociationAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectGatewayAssociationProposalAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectConnectionAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectGatewayAttachmentAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectVirtualInterfaceAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectVirtualGatewayAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectCustomerMetadataAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectLagAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectLocationAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectHostedConnectionAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectInterconnectAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewDirectConnectRouterConfigurationAdapter(directconnectClient, *callerID.Account, cfg.Region, sharedCache), + + // Network Manager + adapters.NewNetworkManagerConnectAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerConnectPeerAssociationAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerConnectPeerAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerCoreNetworkPolicyAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerCoreNetworkAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerNetworkResourceRelationshipsAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerSiteToSiteVpnAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerTransitGatewayConnectPeerAssociationAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerTransitGatewayPeeringAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerTransitGatewayRegistrationAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerTransitGatewayRouteTableAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewNetworkManagerVPCAttachmentAdapter(networkmanagerClient, *callerID.Account, cfg.Region, sharedCache), + + // SQS + adapters.NewSQSQueueAdapter(sqsClient, *callerID.Account, cfg.Region, sharedCache), + + // SNS + adapters.NewSNSSubscriptionAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewSNSTopicAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewSNSPlatformApplicationAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), + 
adapters.NewSNSEndpointAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewSNSDataProtectionPolicyAdapter(snsClient, *callerID.Account, cfg.Region, sharedCache), + + // KMS + adapters.NewKMSKeyAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewKMSCustomKeyStoreAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewKMSAliasAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewKMSGrantAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewKMSKeyPolicyAdapter(kmsClient, *callerID.Account, cfg.Region, sharedCache), + + // ApiGateway + adapters.NewAPIGatewayRestApiAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayResourceAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayDomainNameAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayMethodAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayMethodResponseAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayIntegrationAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayApiKeyAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayAuthorizerAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayDeploymentAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayStageAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + adapters.NewAPIGatewayModelAdapter(apigatewayClient, *callerID.Account, cfg.Region, sharedCache), + + // SSM + adapters.NewSSMParameterAdapter(ssmClient, *callerID.Account, cfg.Region, sharedCache), } + err = e.AddAdapters(configuredAdapters...) if err != nil { - log.WithError(err).Debug("Error initializing sources") - } else { - log.Debug("Sources initialized") - // Start sending heartbeats after adapters are successfully added - // This ensures the first heartbeat has adapters available for readiness checks - e.StartSendingHeartbeats(ctx) - // If there is no error then return the engine - return e, nil + return err } + + // Add "global" sources (those that aren't tied to a region, like + // cloudfront). but only do this once for the first region. 
For + // these APIs it doesn't matter which region we call them from, we + // get global results + if globalDone.CompareAndSwap(false, true) { + globalAdapters := []discovery.Adapter{ + // Cloudfront + adapters.NewCloudfrontCachePolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontContinuousDeploymentPolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontDistributionAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontCloudfrontFunctionAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontKeyGroupAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontOriginAccessControlAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontOriginRequestPolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontResponseHeadersPolicyAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontRealtimeLogConfigsAdapter(cloudfrontClient, *callerID.Account, sharedCache), + adapters.NewCloudfrontStreamingDistributionAdapter(cloudfrontClient, *callerID.Account, sharedCache), + + // S3 + adapters.NewS3Adapter(cfg, *callerID.Account, sharedCache), + + // Networkmanager + adapters.NewNetworkManagerGlobalNetworkAdapter(networkmanagerClient, *callerID.Account, sharedCache), + adapters.NewNetworkManagerSiteAdapter(networkmanagerClient, *callerID.Account, sharedCache), + adapters.NewNetworkManagerLinkAdapter(networkmanagerClient, *callerID.Account, sharedCache), + adapters.NewNetworkManagerDeviceAdapter(networkmanagerClient, *callerID.Account, sharedCache), + adapters.NewNetworkManagerLinkAssociationAdapter(networkmanagerClient, *callerID.Account, sharedCache), + adapters.NewNetworkManagerConnectionAdapter(networkmanagerClient, *callerID.Account, sharedCache), + + // IAM + adapters.NewIAMPolicyAdapter(iamClient, *callerID.Account, sharedCache), + adapters.NewIAMGroupAdapter(iamClient, *callerID.Account, sharedCache), + adapters.NewIAMInstanceProfileAdapter(iamClient, *callerID.Account, sharedCache), + adapters.NewIAMRoleAdapter(iamClient, *callerID.Account, sharedCache), + adapters.NewIAMUserAdapter(iamClient, *callerID.Account, sharedCache), + } + + err = e.AddAdapters(globalAdapters...) + if err != nil { + return err + } + } + return nil + }) + } + + if err := p.Wait(); err != nil { + return err + } + + // Log summary of skipped regions if any + if len(skippedRegions) > 0 { + skippedRegionNames := make([]string, 0, len(skippedRegions)) + for _, sr := range skippedRegions { + skippedRegionNames = append(skippedRegionNames, sr.region) } + log.WithFields(log.Fields{ + "skipped_regions": skippedRegionNames, + "count": len(skippedRegions), + }).Warn("Some regions were skipped because they are not enabled in the AWS account. 
The source will operate normally with the remaining regions.") } + + log.Debug("Sources initialized") + return nil } diff --git a/aws-source/proc/proc_test.go b/aws-source/proc/proc_test.go index d99d5951..630f050c 100644 --- a/aws-source/proc/proc_test.go +++ b/aws-source/proc/proc_test.go @@ -3,9 +3,11 @@ package proc import ( "context" "errors" + "fmt" "strings" "testing" + "github.com/aws/smithy-go" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" "github.com/stretchr/testify/assert" @@ -99,6 +101,93 @@ func TestInitializeAwsSourceEngine_RetryClearsAdapters(t *testing.T) { assert.Contains(t, scopes, "123456789012.us-east-1", "Scope should be present after re-adding") } +// mockAPIError implements smithy.APIError for testing +type mockAPIError struct { + code string + message string +} + +func (m *mockAPIError) Error() string { + return m.message +} + +func (m *mockAPIError) ErrorCode() string { + return m.code +} + +func (m *mockAPIError) ErrorMessage() string { + return m.message +} + +func (m *mockAPIError) ErrorFault() smithy.ErrorFault { + return smithy.FaultUnknown +} + +func TestIsOptInRegionError(t *testing.T) { + tests := []struct { + name string + err error + expectedResult bool + }{ + { + name: "nil error returns false", + err: nil, + expectedResult: false, + }, + { + name: "InvalidIdentityToken with OIDC message returns true", + err: &mockAPIError{ + code: "InvalidIdentityToken", + message: "InvalidIdentityToken: No OpenIDConnect provider found in your account for https://oidc.eks.eu-west-2.amazonaws.com/id/ABC123", + }, + expectedResult: true, + }, + { + name: "wrapped InvalidIdentityToken with OIDC message returns true", + err: fmt.Errorf("operation error STS: AssumeRoleWithWebIdentity: %w", &mockAPIError{ + code: "InvalidIdentityToken", + message: "No OpenIDConnect provider found in your account", + }), + expectedResult: true, + }, + { + name: "InvalidIdentityToken without OIDC message returns false", + err: &mockAPIError{ + code: "InvalidIdentityToken", + message: "Invalid identity token for some other reason", + }, + expectedResult: false, + }, + { + name: "different error code returns false", + err: &mockAPIError{ + code: "AccessDenied", + message: "Access denied", + }, + expectedResult: false, + }, + { + name: "non-AWS error returns false", + err: errors.New("some random error"), + expectedResult: false, + }, + { + name: "error with OIDC text but not API error returns false", + err: errors.New("No OpenIDConnect provider found"), + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isOptInRegionError(tt.err) + if result != tt.expectedResult { + t.Errorf("isOptInRegionError() = %v, want %v for error: %v", result, tt.expectedResult, tt.err) + } + }) + } +} + func TestWrapRegionError(t *testing.T) { tests := []struct { name string @@ -115,22 +204,31 @@ func TestWrapRegionError(t *testing.T) { expectedText: "", }, { - name: "OIDC provider error gets wrapped", - err: errors.New("InvalidIdentityToken: No OpenIDConnect provider found in your account"), + name: "opt-in region error gets wrapped", + err: &mockAPIError{ + code: "InvalidIdentityToken", + message: "No OpenIDConnect provider found in your account", + }, region: "eu-central-2", shouldWrap: true, expectedText: "region 'eu-central-2' is not enabled", }, { - name: "InvalidIdentityToken error without OIDC text not wrapped", - err: errors.New("InvalidIdentityToken: some other message"), + name: "wrapped opt-in region error gets 
additional context", + err: fmt.Errorf("operation error STS: AssumeRoleWithWebIdentity: %w", &mockAPIError{ + code: "InvalidIdentityToken", + message: "No OpenIDConnect provider found in your account", + }), region: "ap-south-2", - shouldWrap: false, - expectedText: "", + shouldWrap: true, + expectedText: "region 'ap-south-2' is not enabled", }, { - name: "AssumeRoleWithWebIdentity exceeded attempts not wrapped", - err: errors.New("operation error STS: AssumeRoleWithWebIdentity, exceeded maximum number of attempts"), + name: "InvalidIdentityToken without OIDC text not wrapped", + err: &mockAPIError{ + code: "InvalidIdentityToken", + message: "some other message", + }, region: "me-central-1", shouldWrap: false, expectedText: "", @@ -166,6 +264,10 @@ func TestWrapRegionError(t *testing.T) { if !strings.Contains(resultMsg, tt.expectedText) { t.Errorf("expected wrapped error to contain '%s', got: %v", tt.expectedText, resultMsg) } + // Verify the original error is preserved (wrapped with %w) + if !errors.Is(result, tt.err) { + t.Errorf("expected wrapped error to contain original error") + } } else { if strings.Contains(resultMsg, "region") && strings.Contains(resultMsg, "not enabled") { t.Errorf("expected error not to be wrapped, but it was: %v", resultMsg) diff --git a/cmd/changes_get_change.go b/cmd/changes_get_change.go index 8254ae7d..4237e4c6 100644 --- a/cmd/changes_get_change.go +++ b/cmd/changes_get_change.go @@ -95,7 +95,7 @@ fetch: } } // display the running entry - runningEntry, status, err := sdp.TimelineFindInProgressEntry(timeLine.GetEntries()) + runningEntry, contentDescription, status, err := sdp.TimelineFindInProgressEntry(timeLine.GetEntries()) if err != nil { return loggedError{ err: err, @@ -107,6 +107,7 @@ fetch: log.WithContext(ctx).WithFields(log.Fields{ "status": status.String(), "running": runningEntry, + "content": contentDescription, }).Info("Waiting for change analysis to complete") // retry time.Sleep(3 * time.Second) @@ -148,7 +149,7 @@ fetch: } log.WithContext(ctx).WithFields(log.Fields{ "ovm.change.uuid": changeUuid.String(), - }).Info("found change") + }).Debug("found change") fmt.Println(changeRes.Msg.GetChange()) diff --git a/cmd/changes_get_signals.go b/cmd/changes_get_signals.go index b8014e14..52d7420c 100644 --- a/cmd/changes_get_signals.go +++ b/cmd/changes_get_signals.go @@ -78,7 +78,7 @@ fetch: } } // display the running entry - runningEntry, status, err := sdp.TimelineFindInProgressEntry(timeLine.GetEntries()) + runningEntry, contentDescription, status, err := sdp.TimelineFindInProgressEntry(timeLine.GetEntries()) if err != nil { return loggedError{ err: err, @@ -90,6 +90,7 @@ fetch: log.WithContext(ctx).WithFields(log.Fields{ "status": status.String(), "running": runningEntry, + "content": contentDescription, }).Info("Waiting for change analysis to complete") // retry time.Sleep(3 * time.Second) @@ -128,7 +129,7 @@ fetch: } log.WithContext(ctx).WithFields(log.Fields{ "ovm.change.uuid": changeUuid.String(), - }).Info("found change signals") + }).Debug("found change signals") fmt.Println(signalsRes.Msg.GetSignals()) diff --git a/cmd/changes_list_changes.go b/cmd/changes_list_changes.go index a7b6c91e..073ff8bc 100644 --- a/cmd/changes_list_changes.go +++ b/cmd/changes_list_changes.go @@ -51,7 +51,7 @@ func ListChanges(cmd *cobra.Command, args []string) error { "change-status": change.GetMetadata().GetStatus().String(), "change-name": change.GetProperties().GetTitle(), "change-description": change.GetProperties().GetDescription(), - }).Info("found 
change") + }).Debug("found change") b, err := json.MarshalIndent(change.ToMap(), "", " ") if err != nil { diff --git a/cmd/changes_start_change.go b/cmd/changes_start_change.go index 12980bde..7d915871 100644 --- a/cmd/changes_start_change.go +++ b/cmd/changes_start_change.go @@ -75,7 +75,7 @@ fetch: } } // display the running entry - runningEntry, status, err := sdp.TimelineFindInProgressEntry(timeLine.GetEntries()) + runningEntry, contentDescription, status, err := sdp.TimelineFindInProgressEntry(timeLine.GetEntries()) if err != nil { return loggedError{ err: err, @@ -87,6 +87,7 @@ fetch: log.WithContext(ctx).WithFields(log.Fields{ "status": status.String(), "running": runningEntry, + "content": contentDescription, }).Info("Waiting for blast radius to be calculated") // retry time.Sleep(3 * time.Second) diff --git a/cmd/explore.go b/cmd/explore.go index b2578a7a..ed2e3550 100644 --- a/cmd/explore.go +++ b/cmd/explore.go @@ -127,14 +127,15 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut MaxParallelExecutions: 2_000, HeartbeatOptions: heartbeatOptions(oi, token), } - stdlibEngine, err := stdlibSource.InitializeEngine( - ctx, - &ec, - true, - ) + stdlibEngine, err := discovery.NewEngine(&ec) if err != nil { - stdlibSpinner.Fail("Failed to initialize stdlib source engine") - return nil, fmt.Errorf("failed to initialize stdlib source engine: %w", err) + stdlibSpinner.Fail("Failed to create stdlib source engine") + return nil, fmt.Errorf("failed to create stdlib source engine: %w", err) + } + err = stdlibSource.InitializeAdapters(ctx, stdlibEngine, true) + if err != nil { + stdlibSpinner.Fail("Failed to initialize stdlib source adapters") + return nil, fmt.Errorf("failed to initialize stdlib source adapters: %w", err) } // todo: pass in context with timeout to abort timely and allow Ctrl-C to work err = stdlibEngine.Start(ctx) @@ -142,6 +143,7 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut stdlibSpinner.Fail("Failed to start stdlib source engine") return nil, fmt.Errorf("failed to start stdlib source engine: %w", err) } + stdlibEngine.StartSendingHeartbeats(ctx) stdlibSpinner.Success("Stdlib source engine started") return []*discovery.Engine{stdlibEngine}, nil }) @@ -213,20 +215,25 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut NATSOptions: &natsOpts, HeartbeatOptions: heartbeatOptions(oi, token), } - awsEngine, err := proc.InitializeAwsSourceEngine( + awsEngine, err := discovery.NewEngine(&ec) + if err != nil { + awsSpinner.Fail("Failed to create AWS source engine") + return nil, fmt.Errorf("failed to create AWS source engine: %w", err) + } + + err = proc.InitializeAwsSourceAdapters( ctx, - &ec, - 1, // Don't retry as we want the user to get notified immediately + awsEngine, configs..., ) if err != nil { if os.Getenv("AWS_PROFILE") == "" { // look for the AWS_PROFILE env var and suggest setting it - awsSpinner.Fail("Failed to initialize AWS source engine. Consider setting AWS_PROFILE to use the default AWS CLI profile.") + awsSpinner.Fail("Failed to initialize AWS source adapters. 
Consider setting AWS_PROFILE to use the default AWS CLI profile.") } else { - awsSpinner.Fail("Failed to initialize AWS source engine") + awsSpinner.Fail("Failed to initialize AWS source adapters") } - return nil, fmt.Errorf("failed to initialize AWS source engine: %w", err) + return nil, fmt.Errorf("failed to initialize AWS source adapters: %w", err) } err = awsEngine.Start(ctx) @@ -234,6 +241,7 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut awsSpinner.Fail("Failed to start AWS source engine") return nil, fmt.Errorf("failed to start AWS source engine: %w", err) } + awsEngine.StartSendingHeartbeats(ctx) awsSpinner.Success("AWS source engine started") foundCloudProvider = true @@ -316,13 +324,23 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut HeartbeatOptions: heartbeatOptions(oi, token), } - gcpEngine, err := gcpproc.Initialize(ctx, &ec, gcpConfig) + gcpEngine, err := discovery.NewEngine(&ec) + if err != nil { + if gcpConfig == nil { + statusArea.Println(fmt.Sprintf("Failed to create GCP source engine with default credentials: %s", err.Error())) + } else { + statusArea.Println(fmt.Sprintf("Failed to create GCP source engine for project %s: %s", gcpConfig.ProjectID, err.Error())) + } + continue // Skip this engine but continue with others + } + + err = gcpproc.InitializeAdapters(ctx, gcpEngine, gcpConfig) if err != nil { if gcpConfig == nil { // Default config failed - statusArea.Println(fmt.Sprintf("Failed to initialize GCP source with default credentials: %s", err.Error())) + statusArea.Println(fmt.Sprintf("Failed to initialize GCP source adapters with default credentials: %s", err.Error())) } else { - statusArea.Println(fmt.Sprintf("Failed to initialize GCP source for project %s: %s", gcpConfig.ProjectID, err.Error())) + statusArea.Println(fmt.Sprintf("Failed to initialize GCP source adapters for project %s: %s", gcpConfig.ProjectID, err.Error())) } continue // Skip this engine but continue with others } @@ -336,6 +354,7 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut } continue // Skip this engine but continue with others } + gcpEngine.StartSendingHeartbeats(ctx) gcpEngines = append(gcpEngines, gcpEngine) } @@ -460,9 +479,15 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut HeartbeatOptions: heartbeatOptions(oi, token), } - azureEngine, err := azureproc.Initialize(ctx, &ec, azureConfig) + azureEngine, err := discovery.NewEngine(&ec) + if err != nil { + statusArea.Println(fmt.Sprintf("Failed to create Azure source engine for subscription %s: %s", azureConfig.SubscriptionID, err.Error())) + continue // Skip this engine but continue with others + } + + err = azureproc.InitializeAdapters(ctx, azureEngine, azureConfig) if err != nil { - statusArea.Println(fmt.Sprintf("Failed to initialize Azure source for subscription %s: %s", azureConfig.SubscriptionID, err.Error())) + statusArea.Println(fmt.Sprintf("Failed to initialize Azure source adapters for subscription %s: %s", azureConfig.SubscriptionID, err.Error())) continue // Skip this engine but continue with others } @@ -471,6 +496,7 @@ func StartLocalSources(ctx context.Context, oi sdp.OvermindInstance, token *oaut statusArea.Println(fmt.Sprintf("Failed to start Azure source for subscription %s: %s", azureConfig.SubscriptionID, err.Error())) continue // Skip this engine but continue with others } + azureEngine.StartSendingHeartbeats(ctx) azureEngines = append(azureEngines, azureEngine) } diff --git 
a/cmd/root.go b/cmd/root.go index c51f5ede..a262b11b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -111,6 +111,15 @@ func Execute() { formatter := new(log.TextFormatter) formatter.DisableTimestamp = true log.SetFormatter(formatter) + log.SetOutput(os.Stderr) + + // Configure pterm to output to stderr instead of stdout + // This ensures status messages don't interfere with piped output + pterm.SetDefaultOutput(os.Stderr) + pterm.Info.Writer = os.Stderr + pterm.Success.Writer = os.Stderr + pterm.Warning.Writer = os.Stderr + pterm.Error.Writer = os.Stderr // create a sub-scope to run deferred cleanups before shutting down the tracer err := func() error { diff --git a/cmd/terraform_plan.go b/cmd/terraform_plan.go index 209a3632..fae484b7 100644 --- a/cmd/terraform_plan.go +++ b/cmd/terraform_plan.go @@ -152,11 +152,12 @@ func TerraformPlanImpl(ctx context.Context, cmd *cobra.Command, oi sdp.OvermindI removingSecretsSpinner.Success(fmt.Sprintf("Removed %v secrets", mappingResponse.RemovedSecrets)) - resourceExtractionSpinner.UpdateText(fmt.Sprintf("Extracted %v changing resources: %v supported %v skipped %v unsupported\n", + resourceExtractionSpinner.UpdateText(fmt.Sprintf("Extracted %v changing resources: %v supported %v skipped %v unsupported %v pending creation\n", mappingResponse.NumTotal(), mappingResponse.NumSuccess(), mappingResponse.NumNotEnoughInfo(), mappingResponse.NumUnsupported(), + mappingResponse.NumPendingCreation(), )) // Sort the supported and unsupported changes so that they display nicely @@ -174,6 +175,8 @@ func TerraformPlanImpl(ctx context.Context, cmd *cobra.Command, oi sdp.OvermindI printer = pterm.Warning case tfutils.MapStatusUnsupported: printer = pterm.Error + case tfutils.MapStatusPendingCreation: + printer = pterm.Info } line := printer.Sprintf("%v (%v)", mapping.TerraformName, mapping.Message) diff --git a/discovery/adapter.go b/discovery/adapter.go index db37cc53..5d0b473d 100644 --- a/discovery/adapter.go +++ b/discovery/adapter.go @@ -86,11 +86,15 @@ type HiddenAdapter interface { } // WildcardScopeAdapter is an optional interface that adapters can implement -// to declare they can handle "*" wildcard scopes efficiently (e.g., using -// GCP's aggregatedList API). When an adapter implements this interface and -// returns true from SupportsWildcardScope(), the engine will pass wildcard -// scopes directly to the adapter instead of expanding them to all configured -// scopes. +// to declare they can handle "*" wildcard scopes efficiently for LIST queries +// (e.g., using GCP's aggregatedList API). When an adapter implements this +// interface and returns true from SupportsWildcardScope(), the engine will +// pass wildcard scopes directly to the adapter instead of expanding them to +// all configured scopes—but only for LIST queries. +// +// For GET and SEARCH, the engine always expands wildcard scope so that +// multiple results can be returned when a resource exists in multiple scopes. +// Future work may extend this optimization to SEARCH once adapters support it. 
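To illustrate the contract above, a minimal sketch of how an adapter opts in follows. The package and receiver type names are hypothetical; only the SupportsWildcardScope method and the LIST-only semantics come from this change (ExpandQuery now checks q.GetMethod() == sdp.QueryMethod_LIST before passing the wildcard scope through). This sketch is not part of the patch.

package adapters

// computeInstanceAdapter is a hypothetical adapter that can answer a wildcard
// LIST query in a single call (think GCP's aggregatedList API). A real adapter
// must also implement discovery.Adapter; the method below is the only thing
// the engine checks for.
type computeInstanceAdapter struct{}

// SupportsWildcardScope tells the engine this adapter can serve a LIST query
// with scope "*" itself, so ExpandQuery will not fan that query out to every
// configured scope. GET and SEARCH queries with a wildcard scope are still
// expanded into one query per configured scope.
func (a *computeInstanceAdapter) SupportsWildcardScope() bool {
	return true
}
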
type WildcardScopeAdapter interface { Adapter SupportsWildcardScope() bool diff --git a/discovery/adapter_test.go b/discovery/adapter_test.go index 4b427c1c..34011567 100644 --- a/discovery/adapter_test.go +++ b/discovery/adapter_test.go @@ -36,7 +36,7 @@ func TestGet(t *testing.T) { "test", "empty", }, - cache: sdpcache.NewCache(t.Context()), + cache: sdpcache.NewMemoryCache(), } e := newStartedEngine(t, "TestGet", nil, nil, &adapter) @@ -146,7 +146,6 @@ func TestGet(t *testing.T) { } time.Sleep(10 * time.Millisecond) - e.sh.Purge(t.Context()) item3, _, _, err = e.executeQuerySync(context.Background(), &req) if err != nil { @@ -251,6 +250,7 @@ func TestGet(t *testing.T) { func TestList(t *testing.T) { adapter := TestAdapter{} + adapter.cache = sdpcache.NewMemoryCache() e := newStartedEngine(t, "TestList", nil, nil, &adapter) @@ -276,6 +276,7 @@ func TestList(t *testing.T) { func TestSearch(t *testing.T) { adapter := TestAdapter{} + adapter.cache = sdpcache.NewMemoryCache() e := newStartedEngine(t, "TestSearch", nil, nil, &adapter) @@ -307,7 +308,7 @@ func TestListSearchCaching(t *testing.T) { "empty", "error", }, - cache: sdpcache.NewCache(t.Context()), + cache: sdpcache.NewMemoryCache(), } e := newStartedEngine(t, "TestListSearchCaching", nil, nil, &adapter) @@ -344,7 +345,6 @@ func TestListSearchCaching(t *testing.T) { } time.Sleep(10 * time.Millisecond) - e.sh.Purge(t.Context()) list3, _, _, err = e.executeQuerySync(context.Background(), &q) if err != nil { @@ -619,7 +619,7 @@ func TestSearchGetCaching(t *testing.T) { ReturnScopes: []string{ "test", }, - cache: sdpcache.NewCache(t.Context()), + cache: sdpcache.NewMemoryCache(), } e := newStartedEngine(t, "TestSearchGetCaching", nil, nil, &adapter) diff --git a/discovery/adapterhost.go b/discovery/adapterhost.go index 26167318..9d34eee4 100644 --- a/discovery/adapterhost.go +++ b/discovery/adapterhost.go @@ -1,12 +1,10 @@ package discovery import ( - "context" "errors" "fmt" "strings" "sync" - "time" "github.com/overmindtech/cli/sdp-go" log "github.com/sirupsen/logrus" @@ -125,9 +123,10 @@ func (sh *AdapterHost) AdaptersByType(typ string) []Adapter { // // The same goes for scopes, if we have a query with a wildcard scope, and // a single adapter that supports 5 scopes, we will end up with 5 queries. The -// exception to this is if we have a adapter that supports all scopes, but is -// unable to list them. In this case there will still be some queries with -// wildcard scopes as they can't be expanded +// exception to this is if we have an adapter that supports all scopes +// (implements WildcardScopeAdapter) and the query method is LIST. In that +// case we pass the wildcard scope directly to the adapter. For GET and +// SEARCH, we always expand so multiple results can be returned. // // This functions returns a map of queries with the adapters that they should be // run against @@ -159,8 +158,10 @@ func (sh *AdapterHost) ExpandQuery(q *sdp.Query) map[*sdp.Query]Adapter { } // If query has wildcard scope and adapter supports wildcards, - // create ONE query with wildcard scope (no expansion) - if supportsWildcard && IsWildcard(q.GetScope()) && !isHidden { + // create ONE query with wildcard scope (no expansion). + // Only for LIST: GET and SEARCH must expand so we can return + // multiple results when a resource exists in multiple scopes. 
+ if supportsWildcard && IsWildcard(q.GetScope()) && !isHidden && q.GetMethod() == sdp.QueryMethod_LIST { dest := proto.Clone(q).(*sdp.Query) dest.Type = adapter.Type() // specialise the query to the adapter type expandedQueries[dest] = adapter @@ -200,26 +201,3 @@ func (sh *AdapterHost) ClearAllAdapters() { sh.adapterIndex = make(map[string]map[string]bool) sh.mutex.Unlock() } - -func (sh *AdapterHost) Purge(ctx context.Context) { - for _, s := range sh.Adapters() { - if c, ok := s.(CachingAdapter); ok { - cache := c.Cache() - if cache != nil { - cache.Purge(ctx, time.Now()) - } - } - } -} - -// ClearCaches Clears caches for all caching adapters -func (sh *AdapterHost) ClearCaches() { - for _, s := range sh.Adapters() { - if c, ok := s.(CachingAdapter); ok { - cache := c.Cache() - if cache != nil { - cache.Clear() - } - } - } -} diff --git a/discovery/adapterhost_bench_test.go b/discovery/adapterhost_bench_test.go index 0132391a..fc6b09da 100644 --- a/discovery/adapterhost_bench_test.go +++ b/discovery/adapterhost_bench_test.go @@ -252,9 +252,7 @@ func (b *BenchmarkListAdapter) List(ctx context.Context, scope string, ignoreCac itemsPerList = 10 // Default to 10 items } - // Check cache first using the embedded TestAdapter's cache - cache := b.Cache() - cacheHit, ck, cachedItems, qErr, done := cache.Lookup(ctx, b.Name(), sdp.QueryMethod_LIST, scope, b.Type(), "", ignoreCache) + cacheHit, ck, cachedItems, qErr, done := b.cache.Lookup(ctx, b.Name(), sdp.QueryMethod_LIST, scope, b.Type(), "", ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -274,7 +272,7 @@ func (b *BenchmarkListAdapter) List(ctx context.Context, scope string, ignoreCac ErrorString: "no items found", Scope: scope, } - cache.StoreError(ctx, err, b.DefaultCacheDuration(), ck) + b.cache.StoreError(ctx, err, b.DefaultCacheDuration(), ck) return nil, err case "error": return nil, &sdp.QueryError{ @@ -288,7 +286,7 @@ func (b *BenchmarkListAdapter) List(ctx context.Context, scope string, ignoreCac for i := range itemsPerList { item := b.NewTestItem(scope, fmt.Sprintf("item-%d", i)) items = append(items, item) - cache.StoreItem(ctx, item, b.DefaultCacheDuration(), ck) + b.cache.StoreItem(ctx, item, b.DefaultCacheDuration(), ck) } return items, nil } @@ -363,6 +361,7 @@ func newBenchmarkEngine(adapters ...Adapter) (*Engine, error) { MaxParallelExecutions: 2000, SourceName: "benchmark-engine", NATSQueueName: "", + Unauthenticated: true, // No NATSOptions - we don't need NATS for benchmarks } diff --git a/discovery/adapterhost_test.go b/discovery/adapterhost_test.go index 069881c1..c706954f 100644 --- a/discovery/adapterhost_test.go +++ b/discovery/adapterhost_test.go @@ -234,10 +234,11 @@ func TestAdapterHostExpandQuery_WildcardScope(t *testing.T) { } }) - t.Run("Wildcard-supporting adapter with wildcard scope does not expand", func(t *testing.T) { + t.Run("Wildcard-supporting adapter with wildcard scope does not expand for LIST", func(t *testing.T) { req := sdp.Query{ - Type: "wildcard-type", - Scope: sdp.WILDCARD, + Type: "wildcard-type", + Method: sdp.QueryMethod_LIST, + Scope: sdp.WILDCARD, } expanded := sh.ExpandQuery(&req) @@ -255,6 +256,50 @@ func TestAdapterHostExpandQuery_WildcardScope(t *testing.T) { } }) + t.Run("Wildcard-supporting adapter with wildcard scope expands for GET", func(t *testing.T) { + req := sdp.Query{ + Type: "wildcard-type", + Method: sdp.QueryMethod_GET, + Scope: sdp.WILDCARD, + } + + expanded := sh.ExpandQuery(&req) + + // Should expand to 2 queries (one per scope) for GET + if len(expanded) 
!= 2 { + t.Fatalf("Expected 2 expanded queries for wildcard adapter with GET, got %v", len(expanded)) + } + + // Check that scopes are specific, not wildcard + for q := range expanded { + if q.GetScope() == sdp.WILDCARD { + t.Errorf("Expected specific scope for GET, got wildcard") + } + } + }) + + t.Run("Wildcard-supporting adapter with wildcard scope expands for SEARCH", func(t *testing.T) { + req := sdp.Query{ + Type: "wildcard-type", + Method: sdp.QueryMethod_SEARCH, + Scope: sdp.WILDCARD, + } + + expanded := sh.ExpandQuery(&req) + + // Should expand to 2 queries (one per scope) for SEARCH + if len(expanded) != 2 { + t.Fatalf("Expected 2 expanded queries for wildcard adapter with SEARCH, got %v", len(expanded)) + } + + // Check that scopes are specific, not wildcard + for q := range expanded { + if q.GetScope() == sdp.WILDCARD { + t.Errorf("Expected specific scope for SEARCH, got wildcard") + } + } + }) + t.Run("Wildcard-supporting adapter with specific scope works normally", func(t *testing.T) { req := sdp.Query{ Type: "wildcard-type", diff --git a/discovery/cmd.go b/discovery/cmd.go index 5d359048..27b2f863 100644 --- a/discovery/cmd.go +++ b/discovery/cmd.go @@ -226,73 +226,98 @@ func MapFromEngineConfig(ec *EngineConfig) map[string]any { } } -// CreateClients we need to have some checks, as it is called by the cli tool +// CreateClients sets up NATS TokenClient and HeartbeatOptions.ManagementClient from config. +// Each client is only created if not already set (idempotent), so callers like the CLI +// can pre-configure clients without them being overwritten. func (ec *EngineConfig) CreateClients() error { // If we are running in unauthenticated mode then do nothing here if ec.Unauthenticated { log.Warn("Using unauthenticated NATS as ALLOW_UNAUTHENTICATED is set") - log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") + if ec.NATSOptions != nil { + log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") + } + return nil + } + + // If both clients are already configured (e.g. 
CLI), skip entirely + if ec.NATSOptions != nil && ec.NATSOptions.TokenClient != nil && + ec.HeartbeatOptions != nil && ec.HeartbeatOptions.ManagementClient != nil { return nil } switch ec.OvermindManagedSource { case sdp.SourceManaged_LOCAL: log.Info("Using API Key for authentication, heartbeats will be sent") - tokenClient, err := auth.NewAPIKeyClient(ec.APIServerURL, ec.ApiKey) - if err != nil { - err = fmt.Errorf("error creating API key client %w", err) - return err - } - tokenSource := auth.NewAPIKeyTokenSource(ec.ApiKey, ec.APIServerURL) - transport := oauth2.Transport{ - Source: tokenSource, - Base: http.DefaultTransport, + + if ec.NATSOptions != nil && ec.NATSOptions.TokenClient == nil { + tokenClient, err := auth.NewAPIKeyClient(ec.APIServerURL, ec.ApiKey) + if err != nil { + return fmt.Errorf("error creating API key client: %w", err) + } + ec.NATSOptions.TokenClient = tokenClient } - authenticatedClient := http.Client{ - Transport: otelhttp.NewTransport(&transport), + + if ec.HeartbeatOptions == nil { + ec.HeartbeatOptions = &HeartbeatOptions{} } - heartbeatOptions := HeartbeatOptions{ - ManagementClient: sdpconnect.NewManagementServiceClient( + if ec.HeartbeatOptions.ManagementClient == nil { + tokenSource := auth.NewAPIKeyTokenSource(ec.ApiKey, ec.APIServerURL) + transport := oauth2.Transport{ + Source: tokenSource, + Base: http.DefaultTransport, + } + authenticatedClient := http.Client{ + Transport: otelhttp.NewTransport(&transport), + } + ec.HeartbeatOptions.ManagementClient = sdpconnect.NewManagementServiceClient( &authenticatedClient, ec.APIServerURL, - ), - Frequency: time.Second * 30, + ) + ec.HeartbeatOptions.Frequency = time.Second * 30 + } + + if ec.NATSOptions != nil { + log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") } - ec.HeartbeatOptions = &heartbeatOptions - ec.NATSOptions.TokenClient = tokenClient - // lets print out the config - log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") return nil case sdp.SourceManaged_MANAGED: log.Info("Using static token for authentication, heartbeats will be sent") - tokenClient, err := auth.NewStaticTokenClient(ec.APIServerURL, ec.SourceAccessToken, ec.SourceAccessTokenType) - if err != nil { - err = fmt.Errorf("error creating static token client %w", err) - sentry.CaptureException(err) - return err - } - tokenSource := oauth2.StaticTokenSource(&oauth2.Token{ - AccessToken: ec.SourceAccessToken, - TokenType: ec.SourceAccessTokenType, - }) - transport := oauth2.Transport{ - Source: tokenSource, - Base: http.DefaultTransport, + + if ec.NATSOptions != nil && ec.NATSOptions.TokenClient == nil { + tokenClient, err := auth.NewStaticTokenClient(ec.APIServerURL, ec.SourceAccessToken, ec.SourceAccessTokenType) + if err != nil { + err = fmt.Errorf("error creating static token client: %w", err) + sentry.CaptureException(err) + return err + } + ec.NATSOptions.TokenClient = tokenClient } - authenticatedClient := http.Client{ - Transport: otelhttp.NewTransport(&transport), + + if ec.HeartbeatOptions == nil { + ec.HeartbeatOptions = &HeartbeatOptions{} } - heartbeatOptions := HeartbeatOptions{ - ManagementClient: sdpconnect.NewManagementServiceClient( + if ec.HeartbeatOptions.ManagementClient == nil { + tokenSource := oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: ec.SourceAccessToken, + TokenType: ec.SourceAccessTokenType, + }) + transport := oauth2.Transport{ + Source: tokenSource, + Base: http.DefaultTransport, + } + authenticatedClient := http.Client{ + Transport: otelhttp.NewTransport(&transport), + } + 
ec.HeartbeatOptions.ManagementClient = sdpconnect.NewManagementServiceClient( &authenticatedClient, ec.APIServerURL, - ), - Frequency: time.Second * 30, + ) + ec.HeartbeatOptions.Frequency = time.Second * 30 + } + + if ec.NATSOptions != nil { + log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") } - ec.NATSOptions.TokenClient = tokenClient - ec.HeartbeatOptions = &heartbeatOptions - // lets print out the config - log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") return nil } diff --git a/discovery/doc.go b/discovery/doc.go new file mode 100644 index 00000000..ab24ca1e --- /dev/null +++ b/discovery/doc.go @@ -0,0 +1,33 @@ +// Package discovery provides the engine and protocol types for Overmind sources. +// Sources discover infrastructure (AWS, K8s, GCP, etc.) and respond to queries via NATS. +// +// # Startup sequence for source authors +// +// Sources should follow this canonical flow so that health probes and heartbeats +// work even when adapter initialization fails (avoiding CrashLoopBackOff): +// +// 1. EngineConfigFromViper(engineType, version) — fail: return/exit +// 2. NewEngine(engineConfig) — fail: return/exit (includes CreateClients internally) +// 3. ServeHealthProbes(port) +// 4. Start(ctx) — fail: return/exit (NATS connection required) +// 5. Validate source config — permanent config errors: SetInitError(err), then idle +// 6. Adapter init — use InitialiseAdapters (blocks until success or ctx cancelled) for retryable init, or SetInitError for single-attempt +// 7. Wait for SIGTERM, then Stop() +// +// # Error handling +// +// Fatal errors (caller must return or exit): EngineConfigFromViper, NewEngine, Start. +// The engine cannot function without a valid config, auth clients, or NATS connection. +// +// Recoverable errors (call SetInitError and keep running): source config validation +// failures (e.g. missing credentials, invalid regions) and adapter initialization +// failures that may be transient. The pod stays Running, readiness fails, and the +// error is reported via heartbeats and the API/UI. +// +// Permanent config errors (e.g. invalid API key, missing required flags) should +// be detected before calling InitialiseAdapters and reported via SetInitError — +// do not retry. Transient adapter init errors (e.g. upstream API temporarily +// unavailable) should use InitialiseAdapters, which retries with backoff. +// +// See SetInitError and InitialiseAdapters for details and examples. +package discovery diff --git a/discovery/engine.go b/discovery/engine.go index 61fa08d0..b9ea9e4a 100644 --- a/discovery/engine.go +++ b/discovery/engine.go @@ -11,6 +11,7 @@ import ( "time" "connectrpc.com/connect" + "github.com/cenkalti/backoff/v5" "github.com/getsentry/sentry-go" "github.com/google/uuid" "github.com/nats-io/nats.go" @@ -149,9 +150,23 @@ type Engine struct { lastSuccessfulHeartbeat time.Time lastHeartbeatError error heartbeatStatusMutex sync.RWMutex + + // initError stores configuration/credential/initialization failures that prevent + // adapters from being added to the engine. This includes: + // - AWS: AssumeRole failures, GetCallerIdentity errors, invalid credentials + // - K8s: Namespace listing failures, kubeconfig errors + // - Harness: API authentication failures, hierarchy discovery errors + // The error is surfaced via readiness checks (pod becomes 0/1 Ready) and + // heartbeats (visible in UI/API), allowing the pod to stay Running instead of + // CrashLoopBackOff so customers can diagnose and fix configuration issues. 
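As a rough illustration of the startup sequence documented above, a minimal source main could look like the sketch below. It assumes EngineConfigFromViper returns (*EngineConfig, error); validateSourceConfig and addMyAdapters are hypothetical placeholders for a real source's config checks and adapter registration, and the ServeHealthProbes step is only noted in a comment because its exact signature is not shown in this diff:

package main

import (
	"context"
	"os/signal"
	"syscall"

	"github.com/overmindtech/cli/discovery"
)

// Hypothetical placeholders for a real source's own logic.
func validateSourceConfig() error                                  { return nil }
func addMyAdapters(ctx context.Context, e *discovery.Engine) error { return nil }

func main() {
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
	defer stop()

	// Assumed signature; check EngineConfigFromViper for the real one.
	cfg, err := discovery.EngineConfigFromViper("my-source", "v0.0.1")
	if err != nil {
		panic(err) // fatal: cannot run without a config
	}

	engine, err := discovery.NewEngine(cfg) // also creates the auth clients
	if err != nil {
		panic(err) // fatal
	}

	// Step 3 would serve health probes here (ServeHealthProbes); omitted
	// because its signature is not shown in this diff.

	if err := engine.Start(ctx); err != nil {
		panic(err) // fatal: the NATS connection is required
	}

	if err := validateSourceConfig(); err != nil {
		// Permanent config error: report it and stay running so probes and
		// heartbeats keep working (no CrashLoopBackOff).
		engine.SetInitError(err)
	} else {
		// Transient init errors are retried with backoff until ctx is cancelled.
		engine.InitialiseAdapters(ctx, func(ctx context.Context) error {
			return addMyAdapters(ctx, engine)
		})
	}

	<-ctx.Done() // SIGTERM
	_ = engine.Stop()
}
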
+ initError error + initErrorMutex sync.RWMutex } func NewEngine(engineConfig *EngineConfig) (*Engine, error) { + if err := engineConfig.CreateClients(); err != nil { + return nil, fmt.Errorf("could not create auth clients: %w", err) + } sh := NewAdapterHost() return &Engine{ EngineConfig: engineConfig, @@ -325,7 +340,8 @@ func (e *Engine) Start(ctx context.Context) error { err := e.connect() //nolint:contextcheck // context is passed in through backgroundJobContext if err != nil { - return e.SendHeartbeat(e.backgroundJobContext, err) //nolint:contextcheck + _ = e.SendHeartbeat(e.backgroundJobContext, err) //nolint:contextcheck + return fmt.Errorf("could not connect to NATS: %w", err) } // Start background jobs @@ -374,9 +390,6 @@ func (e *Engine) Stop() error { if e.heartbeatCancel != nil { e.heartbeatCancel() } - - e.sh.ClearCaches() - return nil } @@ -488,6 +501,11 @@ func (e *Engine) ReadinessHealthCheck(ctx context.Context) error { attribute.String("ovm.healthcheck.type", "readiness"), ) + // Check for persistent initialization errors first + if initErr := e.GetInitError(); initErr != nil { + return fmt.Errorf("source initialization failed: %w", initErr) + } + // Check adapter-specific health using the ReadinessCheck function if e.EngineConfig.HeartbeatOptions != nil && e.EngineConfig.HeartbeatOptions.ReadinessCheck != nil { if err := e.EngineConfig.HeartbeatOptions.ReadinessCheck(ctx); err != nil { @@ -660,11 +678,6 @@ func (e *Engine) HandleLogRecordsRequestWithErrors(ctx context.Context, replyTo return nil } -// ClearCache Completely clears the cache -func (e *Engine) ClearCache() { - e.sh.ClearCaches() -} - // ClearAdapters Deletes all adapters from the engine, allowing new adapters to be // added using `AddAdapter()`. Note that this requires a restart using // `Restart()` in order to take effect @@ -725,6 +738,92 @@ func (e *Engine) AdaptersByType(typ string) []Adapter { return e.sh.AdaptersByType(typ) } +// SetInitError stores a persistent initialization error that will be reported via heartbeat and readiness checks. +// This should be called when source initialization fails in a way that prevents adapters from being added, +// but the process should continue running to serve probes and heartbeats (avoiding CrashLoopBackOff). +// +// Pass nil to clear a previously set error (e.g. after successful retry/restart). +// +// Example usage: +// +// if err := initializeAdapters(); err != nil { +// e.SetInitError(fmt.Errorf("adapter initialization failed: %w", err)) +// // Continue running - pod stays Running with readiness failing +// } +func (e *Engine) SetInitError(err error) { + e.initErrorMutex.Lock() + defer e.initErrorMutex.Unlock() + e.initError = err +} + +// GetInitError returns the persistent initialization error if any. +// Returns nil if no init error is set or if it was cleared via SetInitError(nil). +func (e *Engine) GetInitError() error { + e.initErrorMutex.RLock() + defer e.initErrorMutex.RUnlock() + return e.initError +} + +// InitialiseAdapters retries initFn with exponential backoff (capped at +// 5 minutes) until it succeeds or ctx is cancelled. It blocks the caller. +// +// This is intended for adapter initialization that makes API calls to upstream +// services and may fail transiently. Because it blocks, the caller can +// safely set up namespace watches or other reload mechanisms after it returns +// without racing against a background retry goroutine. +// +// On each attempt: +// - ClearAdapters() is called to remove any leftovers from previous attempts. 
+// - initFn is called. The init error is updated via SetInitError immediately +// (cleared on success, set on failure) and then a heartbeat is sent so the +// API/UI always reflects the current status. +// - On success, StartSendingHeartbeats is called and the function returns. +// +// The caller should have already called Start() before calling this. +func (e *Engine) InitialiseAdapters(ctx context.Context, initFn func(ctx context.Context) error) { + b := backoff.NewExponentialBackOff() + b.MaxInterval = 5 * time.Minute + tick := backoff.NewTicker(b) + defer tick.Stop() + + for { + select { + case <-ctx.Done(): + return + case _, ok := <-tick.C: + if !ok { + // Backoff exhausted (shouldn't happen with default MaxElapsedTime=0) + return + } + + e.ClearAdapters() + + err := initFn(ctx) + + if err != nil { + e.SetInitError(fmt.Errorf("adapter initialisation failed: %w", err)) + log.WithError(err).Warn("Adapter initialisation failed, will retry") + } else { + // Clear any previous init error before the heartbeat so the + // API/UI immediately sees the healthy status. + e.SetInitError(nil) + } + + // Send heartbeat regardless of outcome so the API/UI reflects current status + if hbErr := e.SendHeartbeat(ctx, nil); hbErr != nil { + log.WithError(hbErr).Error("Error sending heartbeat during adapter initialisation") + } + + if err != nil { + continue + } + + e.StartSendingHeartbeats(ctx) + return + } + } +} + // LivenessProbeHandlerFunc returns an HTTP handler function for liveness probes. // This checks only engine initialization (NATS connection, heartbeats) and does NOT check adapter-specific health. func (e *Engine) LivenessProbeHandlerFunc() func(http.ResponseWriter, *http.Request) { diff --git a/discovery/engine_initerror_test.go b/discovery/engine_initerror_test.go new file mode 100644 index 00000000..37ed2404 --- /dev/null +++ b/discovery/engine_initerror_test.go @@ -0,0 +1,365 @@ +package discovery + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "testing" + "time" + + "connectrpc.com/connect" + "github.com/overmindtech/cli/sdp-go" +) + +func TestSetInitError(t *testing.T) { + e := &Engine{ + initError: nil, + initErrorMutex: sync.RWMutex{}, + } + + testErr := errors.New("initialization failed") + e.SetInitError(testErr) + + // Direct pointer comparison is intentional here - we want to verify the exact error object is stored + if e.initError == nil || e.initError.Error() != testErr.Error() { + t.Errorf("expected initError to be %v, got %v", testErr, e.initError) + } +} + +func TestGetInitError(t *testing.T) { + e := &Engine{ + initError: nil, + initErrorMutex: sync.RWMutex{}, + } + + // Test nil case + if err := e.GetInitError(); err != nil { + t.Errorf("expected nil error, got %v", err) + } + + // Test with error set + testErr := errors.New("test error") + e.initError = testErr + + if err := e.GetInitError(); err == nil || err.Error() != testErr.Error() { + t.Errorf("expected error to be %v, got %v", testErr, err) + } +} + +func TestSetInitErrorNil(t *testing.T) { + e := &Engine{ + initError: errors.New("previous error"), + initErrorMutex: sync.RWMutex{}, + } + + // Clear the error + e.SetInitError(nil) + + if e.initError != nil { + t.Errorf("expected initError to be nil after clearing, got %v", e.initError) + } + + if err := e.GetInitError(); err != nil { + t.Errorf("expected GetInitError to return nil after clearing, got %v", err) + } +} + +func TestInitErrorConcurrentAccess(t *testing.T) { + e := &Engine{ + initError: nil, + initErrorMutex: sync.RWMutex{}, + } + + 
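The retry loop in InitialiseAdapters above is a ticker driven by an exponential backoff capped at five minutes. A stripped-down, standalone sketch of that same pattern, reusing the cenkalti/backoff/v5 calls engine.go already imports (attemptInit is a hypothetical stand-in for initFn), looks like this:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v5"
)

// attemptInit is a hypothetical stand-in for the initFn passed to InitialiseAdapters.
func attemptInit(attempt int) error {
	if attempt < 3 {
		return errors.New("transient failure")
	}
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	b := backoff.NewExponentialBackOff()
	b.MaxInterval = 5 * time.Minute // cap the gap between attempts, as InitialiseAdapters does
	tick := backoff.NewTicker(b)
	defer tick.Stop()

	attempt := 0
	for {
		select {
		case <-ctx.Done():
			fmt.Println("gave up: context cancelled")
			return
		case <-tick.C:
			attempt++
			if err := attemptInit(attempt); err != nil {
				fmt.Printf("attempt %d failed: %v (will retry)\n", attempt, err)
				continue
			}
			fmt.Printf("attempt %d succeeded\n", attempt)
			return
		}
	}
}
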
// Test concurrent access from multiple goroutines + var wg sync.WaitGroup + iterations := 100 + + // Writers + for i := range 10 { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := range iterations { + e.SetInitError(fmt.Errorf("error from goroutine %d iteration %d", id, j)) + } + }(i) + } + + // Readers + for range 10 { + wg.Add(1) + go func() { + defer wg.Done() + for range iterations { + _ = e.GetInitError() + } + }() + } + + wg.Wait() + + // Should not panic - error should be one of the written values or nil + finalErr := e.GetInitError() + if finalErr == nil { + t.Log("Final error is nil (acceptable in concurrent test)") + } else { + t.Logf("Final error: %v", finalErr) + } +} + +func TestReadinessHealthCheckWithInitError(t *testing.T) { + ec := &EngineConfig{ + EngineType: "test", + SourceName: "test-source", + HeartbeatOptions: &HeartbeatOptions{ + ReadinessCheck: func(ctx context.Context) error { + // Adapter health is fine + return nil + }, + }, + } + + e, err := NewEngine(ec) + if err != nil { + t.Fatalf("failed to create engine: %v", err) + } + + ctx := context.Background() + + // Readiness should pass when no init error + if err := e.ReadinessHealthCheck(ctx); err != nil { + t.Errorf("expected readiness to pass with no init error, got: %v", err) + } + + // Set an init error + testErr := errors.New("AWS AssumeRole denied") + e.SetInitError(testErr) + + // Readiness should now fail with the init error + err = e.ReadinessHealthCheck(ctx) + if err == nil { + t.Error("expected readiness to fail with init error, got nil") + } else if !errors.Is(err, testErr) { + t.Errorf("expected readiness error to wrap init error, got: %v", err) + } + + // Clear the init error + e.SetInitError(nil) + + // Readiness should pass again + if err := e.ReadinessHealthCheck(ctx); err != nil { + t.Errorf("expected readiness to pass after clearing init error, got: %v", err) + } +} + +func TestSendHeartbeatWithInitError(t *testing.T) { + requests := make(chan *connect.Request[sdp.SubmitSourceHeartbeatRequest], 10) + responses := make(chan *connect.Response[sdp.SubmitSourceHeartbeatResponse], 10) + + ec := &EngineConfig{ + EngineType: "test", + SourceName: "test-source", + HeartbeatOptions: &HeartbeatOptions{ + ManagementClient: testHeartbeatClient{ + Requests: requests, + Responses: responses, + }, + Frequency: 0, // Disable automatic heartbeats + ReadinessCheck: func(ctx context.Context) error { + return nil // Adapters are fine + }, + }, + } + + e, err := NewEngine(ec) + if err != nil { + t.Fatalf("failed to create engine: %v", err) + } + + ctx := context.Background() + + // Send heartbeat with init error + testErr := errors.New("configuration error: invalid credentials") + e.SetInitError(testErr) + + responses <- &connect.Response[sdp.SubmitSourceHeartbeatResponse]{ + Msg: &sdp.SubmitSourceHeartbeatResponse{}, + } + + err = e.SendHeartbeat(ctx, nil) + if err != nil { + t.Errorf("expected SendHeartbeat to succeed, got: %v", err) + } + + // Verify the heartbeat included the init error + req := <-requests + if req.Msg.GetError() == "" { + t.Error("expected heartbeat to include error, got empty string") + } else if !strings.Contains(req.Msg.GetError(), testErr.Error()) { + t.Errorf("expected heartbeat error to contain %q, got: %q", testErr.Error(), req.Msg.GetError()) + } +} + +func TestSendHeartbeatWithInitErrorAndCustomError(t *testing.T) { + requests := make(chan *connect.Request[sdp.SubmitSourceHeartbeatRequest], 10) + responses := make(chan *connect.Response[sdp.SubmitSourceHeartbeatResponse], 
10) + + ec := &EngineConfig{ + EngineType: "test", + SourceName: "test-source", + HeartbeatOptions: &HeartbeatOptions{ + ManagementClient: testHeartbeatClient{ + Requests: requests, + Responses: responses, + }, + Frequency: 0, + }, + } + + e, err := NewEngine(ec) + if err != nil { + t.Fatalf("failed to create engine: %v", err) + } + + ctx := context.Background() + + // Set init error and send heartbeat with custom error + initErr := errors.New("init failed: invalid config") + customErr := errors.New("custom error: readiness failed") + e.SetInitError(initErr) + + responses <- &connect.Response[sdp.SubmitSourceHeartbeatResponse]{ + Msg: &sdp.SubmitSourceHeartbeatResponse{}, + } + + err = e.SendHeartbeat(ctx, customErr) + if err != nil { + t.Errorf("expected SendHeartbeat to succeed, got: %v", err) + } + + // Verify both errors are included in the heartbeat + req := <-requests + if req.Msg.GetError() == "" { + t.Error("expected heartbeat to include errors, got empty string") + } else { + errMsg := req.Msg.GetError() + // Both errors should be in the joined error string + if !strings.Contains(errMsg, initErr.Error()) { + t.Errorf("expected heartbeat error to include init error %q, got: %q", initErr.Error(), errMsg) + } + if !strings.Contains(errMsg, customErr.Error()) { + t.Errorf("expected heartbeat error to include custom error %q, got: %q", customErr.Error(), errMsg) + } + } +} + +func TestInitialiseAdapters_Success(t *testing.T) { + ec := &EngineConfig{ + EngineType: "test", + SourceName: "test-source", + HeartbeatOptions: &HeartbeatOptions{ + Frequency: 0, // Disable automatic heartbeats from StartSendingHeartbeats + }, + } + e, err := NewEngine(ec) + if err != nil { + t.Fatalf("failed to create engine: %v", err) + } + + // Set an init error to verify it gets cleared on success + e.SetInitError(errors.New("previous error")) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var called bool + e.InitialiseAdapters(ctx, func(ctx context.Context) error { + called = true + return nil + }) + + if !called { + t.Error("initFn was not called") + } + if err := e.GetInitError(); err != nil { + t.Errorf("expected init error to be cleared after success, got: %v", err) + } +} + +func TestInitialiseAdapters_RetryThenSuccess(t *testing.T) { + ec := &EngineConfig{ + EngineType: "test", + SourceName: "test-source", + HeartbeatOptions: &HeartbeatOptions{ + Frequency: 0, + }, + } + e, err := NewEngine(ec) + if err != nil { + t.Fatalf("failed to create engine: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + attempts := 0 + e.InitialiseAdapters(ctx, func(ctx context.Context) error { + attempts++ + if attempts < 3 { + return fmt.Errorf("transient error attempt %d", attempts) + } + return nil + }) + + if attempts < 3 { + t.Errorf("expected at least 3 attempts, got %d", attempts) + } + if err := e.GetInitError(); err != nil { + t.Errorf("expected init error to be cleared after eventual success, got: %v", err) + } +} + +func TestInitialiseAdapters_ContextCancelled(t *testing.T) { + ec := &EngineConfig{ + EngineType: "test", + SourceName: "test-source", + HeartbeatOptions: &HeartbeatOptions{ + Frequency: 0, + }, + } + e, err := NewEngine(ec) + if err != nil { + t.Fatalf("failed to create engine: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + var callCount int + + // InitialiseAdapters blocks; cancel ctx after a short delay so it returns + time.AfterFunc(500*time.Millisecond, 
cancel) + + done := make(chan struct{}) + go func() { + e.InitialiseAdapters(ctx, func(ctx context.Context) error { + callCount++ + return errors.New("always fails") + }) + close(done) + }() + + select { + case <-done: + // InitialiseAdapters returned (ctx was cancelled) + case <-time.After(5 * time.Second): + t.Fatal("InitialiseAdapters did not return after context cancellation") + } + + if callCount == 0 { + t.Error("expected initFn to be called at least once before context cancellation") + } + if err := e.GetInitError(); err == nil { + t.Error("expected init error to be set after context cancellation with failures") + } +} diff --git a/discovery/engine_test.go b/discovery/engine_test.go index 63343d37..8dd7ad6d 100644 --- a/discovery/engine_test.go +++ b/discovery/engine_test.go @@ -13,6 +13,7 @@ import ( "github.com/nats-io/nats-server/v2/test" "github.com/overmindtech/cli/auth" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" "golang.org/x/oauth2" ) @@ -30,6 +31,9 @@ func newEngine(t *testing.T, name string, no *auth.NATSOptions, eConn sdp.Encode } if no != nil { ec.NATSOptions = no + if no.TokenClient == nil { + ec.Unauthenticated = true + } } else if eConn == nil { ec.NATSOptions = &auth.NATSOptions{ NumRetries: 5, @@ -194,6 +198,7 @@ func TestNats(t *testing.T) { ec := EngineConfig{ MaxParallelExecutions: 10, SourceName: "nats-test", + Unauthenticated: true, NATSOptions: &auth.NATSOptions{ NumRetries: 5, RetryDelay: time.Second, @@ -211,7 +216,7 @@ func TestNats(t *testing.T) { } adapter := TestAdapter{} - + adapter.cache = sdpcache.NewNoOpCache() err = e.AddAdapters( &adapter, &TestAdapter{ @@ -220,6 +225,7 @@ func TestNats(t *testing.T) { }, ReturnName: "test-adapter", ReturnType: "test", + cache: sdpcache.NewNoOpCache(), }, ) if err != nil { @@ -256,7 +262,6 @@ func TestNats(t *testing.T) { t.Run("Handling a basic query", func(t *testing.T) { t.Cleanup(func() { adapter.ClearCalls() - e.ClearCache() }) query := &sdp.Query{ @@ -293,6 +298,7 @@ func TestNatsCancel(t *testing.T) { ec := EngineConfig{ MaxParallelExecutions: 1, SourceName: "nats-test", + Unauthenticated: true, NATSOptions: &auth.NATSOptions{ NumRetries: 5, RetryDelay: time.Second, @@ -379,6 +385,7 @@ func TestNatsConnections(t *testing.T) { ec := EngineConfig{ MaxParallelExecutions: 1, SourceName: "nats-test", + Unauthenticated: true, NATSOptions: &auth.NATSOptions{ Servers: []string{"nats://bad.server"}, ConnectionName: "test-disconnection", @@ -419,6 +426,7 @@ func TestNatsConnections(t *testing.T) { ec := EngineConfig{ MaxParallelExecutions: 1, SourceName: "nats-test", + Unauthenticated: true, NATSOptions: &auth.NATSOptions{ NumRetries: 5, RetryDelay: time.Second, @@ -485,6 +493,7 @@ func TestNatsConnections(t *testing.T) { ec := EngineConfig{ MaxParallelExecutions: 1, SourceName: "nats-test", + Unauthenticated: true, NATSOptions: &auth.NATSOptions{ NumRetries: 10, RetryDelay: time.Second, @@ -539,6 +548,7 @@ func TestNATSFailureRestart(t *testing.T) { ec := EngineConfig{ MaxParallelExecutions: 1, SourceName: "nats-test", + Unauthenticated: true, NATSOptions: &auth.NATSOptions{ NumRetries: 10, RetryDelay: time.Second, @@ -621,6 +631,7 @@ func TestNatsAuth(t *testing.T) { } adapter := TestAdapter{} + adapter.cache = sdpcache.NewNoOpCache() if err := e.AddAdapters( &adapter, &TestAdapter{ @@ -629,6 +640,7 @@ func TestNatsAuth(t *testing.T) { }, ReturnType: "test", ReturnName: "test-adapter", + cache: sdpcache.NewNoOpCache(), }, ); err != nil { t.Fatalf("Error adding adapters: %v", err) @@ 
-648,7 +660,6 @@ func TestNatsAuth(t *testing.T) { t.Run("Handling a basic query", func(t *testing.T) { t.Cleanup(func() { adapter.ClearCalls() - e.ClearCache() }) query := &sdp.Query{ diff --git a/discovery/enginerequests_test.go b/discovery/enginerequests_test.go index f7593dd7..8dfb2787 100644 --- a/discovery/enginerequests_test.go +++ b/discovery/enginerequests_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/overmindtech/cli/auth" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" "github.com/overmindtech/cli/tracing" "github.com/sourcegraph/conc/pool" "google.golang.org/protobuf/types/known/timestamppb" @@ -44,6 +45,7 @@ func TestExecuteQuery(t *testing.T) { adapter := TestAdapter{ ReturnType: "person", ReturnScopes: []string{"test"}, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestExecuteQuery", @@ -209,6 +211,7 @@ func TestHandleQuery(t *testing.T) { "test1", "test2", }, + cache: sdpcache.NewNoOpCache(), } dogAdapter := TestAdapter{ @@ -218,6 +221,7 @@ func TestHandleQuery(t *testing.T) { "testA", "testB", }, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestHandleQuery", nil, nil, &personAdapter, &dogAdapter) @@ -286,6 +290,7 @@ func TestWildcardAdapterExpansion(t *testing.T) { ReturnScopes: []string{ sdp.WILDCARD, }, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestWildcardAdapterExpansion", nil, nil, &personAdapter) @@ -337,6 +342,7 @@ func TestSendQuerySync(t *testing.T) { ReturnScopes: []string{ "test", }, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestSendQuerySync", nil, nil, &adapter) @@ -388,6 +394,7 @@ func TestExpandQuery(t *testing.T) { ReturnScopes: []string{ "test1", }, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestExpandQuery", nil, nil, &simple) @@ -411,6 +418,7 @@ func TestExpandQuery(t *testing.T) { "test2", "test3", }, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestExpandQuery", nil, nil, &many) @@ -433,6 +441,7 @@ func TestExpandQuery(t *testing.T) { ReturnScopes: []string{ sdp.WILDCARD, }, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestExpandQuery", nil, nil, &sx) diff --git a/discovery/heartbeat.go b/discovery/heartbeat.go index 0560962f..ee7563ad 100644 --- a/discovery/heartbeat.go +++ b/discovery/heartbeat.go @@ -42,8 +42,13 @@ func (e *Engine) SendHeartbeat(ctx context.Context, customErr error) error { return ErrNoHealthcheckDefined } + // No-op when running without management API (e.g. 
ALLOW_UNAUTHENTICATED local dev) if e.EngineConfig.HeartbeatOptions.ManagementClient == nil { - return errors.New("management client is not set") + log.WithFields(log.Fields{ + "source_name": e.EngineConfig.SourceName, + "engine_type": e.EngineConfig.EngineType, + }).Info("Running in unauthenticated mode; no heartbeats will be sent") + return nil } // Collect all health check errors @@ -52,6 +57,11 @@ func (e *Engine) SendHeartbeat(ctx context.Context, customErr error) error { allErrors = append(allErrors, customErr) } + // Check for persistent initialization errors first + if initErr := e.GetInitError(); initErr != nil { + allErrors = append(allErrors, initErr) + } + // Check adapter readiness (ReadinessCheck) - with timeout to prevent hanging if e.EngineConfig.HeartbeatOptions.ReadinessCheck != nil { // Add timeout for readiness checks to prevent hanging heartbeats diff --git a/discovery/heartbeat_test.go b/discovery/heartbeat_test.go index 5ea1edd6..74b79e0b 100644 --- a/discovery/heartbeat_test.go +++ b/discovery/heartbeat_test.go @@ -181,3 +181,27 @@ func TestHeartbeats(t *testing.T) { } }) } + +// TestSendHeartbeatNilManagementClient ensures unauthenticated/local dev mode +// (HeartbeatOptions set by SetReadinessCheck but ManagementClient nil) does not error. +func TestSendHeartbeatNilManagementClient(t *testing.T) { + ec := EngineConfig{ + SourceName: t.Name(), + SourceUUID: uuid.New(), + Version: "v0.0.0-test", + EngineType: "aws", + Unauthenticated: true, + HeartbeatOptions: &HeartbeatOptions{ + ManagementClient: nil, // e.g. ALLOW_UNAUTHENTICATED - no API to send to + Frequency: time.Second * 30, + }, + } + e, err := NewEngine(&ec) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + err = e.SendHeartbeat(context.Background(), nil) + if err != nil { + t.Errorf("SendHeartbeat with nil ManagementClient should be no-op, got: %v", err) + } +} diff --git a/discovery/performance_test.go b/discovery/performance_test.go index b10d8d65..a6b4bd8f 100644 --- a/discovery/performance_test.go +++ b/discovery/performance_test.go @@ -126,6 +126,7 @@ type TimedResults struct { func TimeQueries(t *testing.T, numQueries int, linkDepth int, numParallel int) TimedResults { ec := EngineConfig{ MaxParallelExecutions: numParallel, + Unauthenticated: true, NATSOptions: &auth.NATSOptions{ NumRetries: 5, RetryDelay: time.Second, diff --git a/discovery/querytracker_test.go b/discovery/querytracker_test.go index 530079c3..9a8a31a5 100644 --- a/discovery/querytracker_test.go +++ b/discovery/querytracker_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" "google.golang.org/protobuf/types/known/structpb" ) @@ -96,6 +97,7 @@ func TestExecute(t *testing.T) { ReturnScopes: []string{ "test", }, + cache: sdpcache.NewNoOpCache(), } e := newStartedEngine(t, "TestExecute", nil, nil, &adapter) diff --git a/discovery/shared_test.go b/discovery/shared_test.go index 3bcf4536..ce2943f5 100644 --- a/discovery/shared_test.go +++ b/discovery/shared_test.go @@ -87,9 +87,6 @@ func NewTestAdapter() *TestAdapter { } } -// assert interface implementation -var _ CachingAdapter = (*TestAdapter)(nil) - // ClearCalls Clears the call counters between tests func (s *TestAdapter) ClearCalls() { s.mutex.Lock() @@ -126,21 +123,6 @@ func (s *TestAdapter) Metadata() *sdp.AdapterMetadata { } } -var ( - noOpCacheTestOnce sync.Once - noOpCacheTest sdpcache.Cache -) - -func (s *TestAdapter) Cache() sdpcache.Cache { - if s.cache == nil { - 
noOpCacheTestOnce.Do(func() { - noOpCacheTest = sdpcache.NewNoOpCache() - }) - return noOpCacheTest - } - return s.cache -} - func (s *TestAdapter) Scopes() []string { if len(s.ReturnScopes) > 0 { return s.ReturnScopes @@ -163,7 +145,7 @@ func (s *TestAdapter) Get(ctx context.Context, scope string, query string, ignor var qErr *sdp.QueryError var done func() - cacheHit, ck, cachedItems, qErr, done = s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.Type(), query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done = s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.Type(), query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -185,7 +167,7 @@ func (s *TestAdapter) Get(ctx context.Context, scope string, query string, ignor ErrorString: "no items found", Scope: scope, } - s.Cache().StoreError(ctx, err, s.DefaultCacheDuration(), ck) + s.cache.StoreError(ctx, err, s.DefaultCacheDuration(), ck) return nil, err case "error": return nil, &sdp.QueryError{ @@ -195,7 +177,7 @@ func (s *TestAdapter) Get(ctx context.Context, scope string, query string, ignor } default: item := s.NewTestItem(scope, query) - s.Cache().StoreItem(ctx, item, s.DefaultCacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.DefaultCacheDuration(), ck) return item, nil } } @@ -210,7 +192,7 @@ func (s *TestAdapter) List(ctx context.Context, scope string, ignoreCache bool) var qErr *sdp.QueryError var done func() - cacheHit, ck, cachedItems, qErr, done = s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.Type(), "", ignoreCache) + cacheHit, ck, cachedItems, qErr, done = s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.Type(), "", ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -228,7 +210,7 @@ func (s *TestAdapter) List(ctx context.Context, scope string, ignoreCache bool) ErrorString: "no items found", Scope: scope, } - s.Cache().StoreError(ctx, err, s.DefaultCacheDuration(), ck) + s.cache.StoreError(ctx, err, s.DefaultCacheDuration(), ck) return nil, err case "error": return nil, &sdp.QueryError{ @@ -239,7 +221,7 @@ func (s *TestAdapter) List(ctx context.Context, scope string, ignoreCache bool) default: item := s.NewTestItem(scope, "Dylan") items := []*sdp.Item{item} - s.Cache().StoreItem(ctx, item, s.DefaultCacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.DefaultCacheDuration(), ck) return items, nil } } @@ -254,7 +236,7 @@ func (s *TestAdapter) Search(ctx context.Context, scope string, query string, ig var qErr *sdp.QueryError var done func() - cacheHit, ck, cachedItems, qErr, done = s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.Type(), query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done = s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_SEARCH, scope, s.Type(), query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -272,7 +254,7 @@ func (s *TestAdapter) Search(ctx context.Context, scope string, query string, ig ErrorString: "no items found", Scope: scope, } - s.Cache().StoreError(ctx, err, s.DefaultCacheDuration(), ck) + s.cache.StoreError(ctx, err, s.DefaultCacheDuration(), ck) return nil, err case "error": return nil, &sdp.QueryError{ @@ -283,7 +265,7 @@ func (s *TestAdapter) Search(ctx context.Context, scope string, query string, ig default: item := s.NewTestItem(scope, "Dylan") items := []*sdp.Item{item} - s.Cache().StoreItem(ctx, item, s.DefaultCacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.DefaultCacheDuration(), ck) return items, nil } } diff --git a/go.mod b/go.mod index 03a4937e..c076d262 
100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,10 @@ require ( atomicgo.dev/keyboard v0.2.9 buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20251209175733-2a1774d88802.1 buf.build/go/protovalidate v1.1.0 - cloud.google.com/go/aiplatform v1.114.0 + cloud.google.com/go/aiplatform v1.115.0 cloud.google.com/go/auth v0.18.1 - cloud.google.com/go/bigquery v1.72.0 - cloud.google.com/go/bigtable v1.41.0 + cloud.google.com/go/bigquery v1.73.1 + cloud.google.com/go/bigtable v1.42.0 cloud.google.com/go/compute v1.54.0 cloud.google.com/go/compute/metadata v0.9.0 // indirect cloud.google.com/go/container v1.46.0 @@ -25,7 +25,7 @@ require ( cloud.google.com/go/functions v1.19.7 cloud.google.com/go/iam v1.5.3 cloud.google.com/go/kms v1.25.0 - cloud.google.com/go/logging v1.13.1 + cloud.google.com/go/logging v1.13.2 cloud.google.com/go/monitoring v1.24.3 cloud.google.com/go/networksecurity v0.11.0 cloud.google.com/go/orgpolicy v1.15.1 @@ -60,14 +60,14 @@ require ( github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.17 github.com/aws/aws-sdk-go-v2/service/apigateway v1.38.4 github.com/aws/aws-sdk-go-v2/service/autoscaling v1.64.0 - github.com/aws/aws-sdk-go-v2/service/cloudfront v1.59.0 + github.com/aws/aws-sdk-go-v2/service/cloudfront v1.60.0 github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.53.1 github.com/aws/aws-sdk-go-v2/service/directconnect v1.38.11 - github.com/aws/aws-sdk-go-v2/service/dynamodb v1.54.0 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0 github.com/aws/aws-sdk-go-v2/service/ec2 v1.285.0 github.com/aws/aws-sdk-go-v2/service/ecs v1.71.0 github.com/aws/aws-sdk-go-v2/service/efs v1.41.10 - github.com/aws/aws-sdk-go-v2/service/eks v1.77.0 + github.com/aws/aws-sdk-go-v2/service/eks v1.77.1 github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.33.19 github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.54.6 github.com/aws/aws-sdk-go-v2/service/iam v1.53.2 @@ -92,11 +92,11 @@ require ( github.com/go-jose/go-jose/v4 v4.1.3 github.com/google/btree v1.1.3 github.com/google/uuid v1.6.0 - github.com/googleapis/gax-go/v2 v2.16.0 + github.com/googleapis/gax-go/v2 v2.17.0 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e github.com/hashicorp/go-retryablehttp v0.7.8 github.com/hashicorp/hcl/v2 v2.24.0 - github.com/hashicorp/terraform-config-inspect v0.0.0-20260120201749-785479628bd7 + github.com/hashicorp/terraform-config-inspect v0.0.0-20260204111900-477360eb0c77 github.com/jedib0t/go-pretty/v6 v6.7.8 github.com/micahhausler/aws-iam-policy v0.4.2 github.com/miekg/dns v1.1.72 @@ -105,7 +105,7 @@ require ( github.com/nats-io/jwt/v2 v2.8.0 github.com/nats-io/nats-server/v2 v2.12.4 github.com/nats-io/nats.go v1.48.0 - github.com/nats-io/nkeys v0.4.12 + github.com/nats-io/nkeys v0.4.15 github.com/onsi/ginkgo/v2 v2.28.1 // indirect github.com/onsi/gomega v1.39.1 // indirect github.com/openrdap/rdap v0.9.2-0.20240517203139-eb57b3a8dedd @@ -123,13 +123,13 @@ require ( github.com/zclconf/go-cty v1.17.0 go.etcd.io/bbolt v1.4.3 go.opentelemetry.io/contrib/detectors/aws/ec2/v2 v2.0.0-20250901115419-474a7992e57c - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 - go.opentelemetry.io/otel v1.39.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 - go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0 - go.opentelemetry.io/otel/sdk v1.39.0 - go.opentelemetry.io/otel/trace v1.39.0 + 
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 + go.opentelemetry.io/otel v1.40.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.40.0 + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.40.0 + go.opentelemetry.io/otel/sdk v1.40.0 + go.opentelemetry.io/otel/trace v1.40.0 go.uber.org/automaxprocs v1.6.0 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 @@ -138,8 +138,9 @@ require ( golang.org/x/sync v0.19.0 golang.org/x/text v0.33.0 gonum.org/v1/gonum v0.17.0 - google.golang.org/api v0.264.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 + google.golang.org/api v0.265.0 + google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 gopkg.in/ini.v1 v1.67.1 @@ -216,7 +217,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect github.com/gookit/color v1.5.4 // indirect github.com/gorilla/css v1.0.1 // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -262,7 +263,7 @@ require ( go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/otel/log v0.11.0 // indirect - go.opentelemetry.io/otel/metric v1.39.0 // indirect + go.opentelemetry.io/otel/metric v1.40.0 // indirect go.opentelemetry.io/otel/schema v0.0.12 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect @@ -276,8 +277,7 @@ require ( golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.41.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260122232226-8e98ce8d340d // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect gopkg.in/inf.v0 v0.9.1 // indirect @@ -285,6 +285,6 @@ require ( k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect - sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect + sigs.k8s.io/structured-merge-diff/v6 v6.3.2-0.20260122202528-d9cc6641c482 // indirect sigs.k8s.io/yaml v1.6.0 // indirect ) diff --git a/go.sum b/go.sum index 6abb810c..15ea0ab2 100644 --- a/go.sum +++ b/go.sum @@ -16,16 +16,16 @@ cel.dev/expr v0.24.0 h1:56OvJKSH3hDGL0ml5uSxZmz3/3Pq4tJ+fb1unVLAFcY= cel.dev/expr v0.24.0/go.mod h1:hLPLo1W4QUmuYdA72RBX06QTs6MXw941piREPl3Yfiw= cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= -cloud.google.com/go/aiplatform v1.114.0 h1:TCrSLci+NFEAx0PZMv8btGe5j68RivArmDJbBLIc/3o= -cloud.google.com/go/aiplatform v1.114.0/go.mod h1:W5yMrpIuHG/CSK8iF7XnwIfCJu6dcLRQ0cTqGR5vwwE= +cloud.google.com/go/aiplatform v1.115.0 h1:m/dIJ/HixZDvHoXBGkA5Sd0RbiQp5lBVyddvR9uxHqI= 
+cloud.google.com/go/aiplatform v1.115.0/go.mod h1:DwPJAxebOTy6BajSMjF7ah3QvlYO4jf2gpJw6/1z9gU= cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs= cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= -cloud.google.com/go/bigquery v1.72.0 h1:D/yLju+3Ens2IXx7ou1DJ62juBm+/coBInn4VVOg5Cw= -cloud.google.com/go/bigquery v1.72.0/go.mod h1:GUbRtmeCckOE85endLherHD9RsujY+gS7i++c1CqssQ= -cloud.google.com/go/bigtable v1.41.0 h1:99KOWShm/MUyuIbXBeVscdWJFV7GdgiYwFUrB5Iu4BI= -cloud.google.com/go/bigtable v1.41.0/go.mod h1:JlaltP06LEFXaxQdZiarGR9tKsX/II0IkNAKMDrWspI= +cloud.google.com/go/bigquery v1.73.1 h1:v//GZwdhtmCbZ87rOnxz7pectOGFS1GNRvrGTvLzka4= +cloud.google.com/go/bigquery v1.73.1/go.mod h1:KSLx1mKP/yGiA8U+ohSrqZM1WknUnjZAxHAQZ51/b1k= +cloud.google.com/go/bigtable v1.42.0 h1:SREvT4jLhJQZXUjsLmFs/1SMQJ+rKEj1cJuPE9liQs8= +cloud.google.com/go/bigtable v1.42.0/go.mod h1:oZ30nofVB6/UYGg7lBwGLWSea7NZUvw/WvBBgLY07xU= cloud.google.com/go/compute v1.54.0 h1:4CKmnpO+40z44bKG5bdcKxQ7ocNpRtOc9SCLLUzze1w= cloud.google.com/go/compute v1.54.0/go.mod h1:RfBj0L1x/pIM84BrzNX2V21oEv16EKRPBiTcBRRH1Ww= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= @@ -46,8 +46,8 @@ cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= cloud.google.com/go/kms v1.25.0 h1:gVqvGGUmz0nYCmtoxWmdc1wli2L1apgP8U4fghPGSbQ= cloud.google.com/go/kms v1.25.0/go.mod h1:XIdHkzfj0bUO3E+LvwPg+oc7s58/Ns8Nd8Sdtljihbk= -cloud.google.com/go/logging v1.13.1 h1:O7LvmO0kGLaHY/gq8cV7T0dyp6zJhYAOtZPX4TF3QtY= -cloud.google.com/go/logging v1.13.1/go.mod h1:XAQkfkMBxQRjQek96WLPNze7vsOmay9H5PqfsNYDqvw= +cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA= +cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak= cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= @@ -68,8 +68,8 @@ cloud.google.com/go/securitycentermanagement v1.1.6 h1:XFqjKq4ZpKTj8xCXWs/mTmh/U cloud.google.com/go/securitycentermanagement v1.1.6/go.mod h1:nt5Z6rU4s2/j8R/EQxG5K7OfVAfAfwo89j0Nx2Srzaw= cloud.google.com/go/spanner v1.87.0 h1:M9RGcj/4gJk6yY1lRLOz1Ze+5ufoWhbIiurzXLOOfcw= cloud.google.com/go/spanner v1.87.0/go.mod h1:tcj735Y2aqphB6/l+X5MmwG4NnV+X1NJIbFSZGaHYXw= -cloud.google.com/go/storage v1.56.0 h1:iixmq2Fse2tqxMbWhLWC9HfBj1qdxqAmiK8/eqtsLxI= -cloud.google.com/go/storage v1.56.0/go.mod h1:Tpuj6t4NweCLzlNbw9Z9iwxEkrSem20AetIeH/shgVU= +cloud.google.com/go/storage v1.59.0 h1:9p3yDzEN9Vet4JnbN90FECIw6n4FCXcKBK1scxtQnw8= +cloud.google.com/go/storage v1.59.0/go.mod h1:cMWbtM+anpC74gn6qjLh+exqYcfmB9Hqe5z6adx+CLI= cloud.google.com/go/storagetransfer v1.13.1 h1:Sjukr1LtUt7vLTHNvGc2gaAqlXNFeDFRIRmWGrFaJlY= cloud.google.com/go/storagetransfer v1.13.1/go.mod h1:S858w5l383ffkdqAqrAA+BC7KlhCqeNieK3sFf5Bj4Y= connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= @@ -182,22 +182,22 @@ github.com/aws/aws-sdk-go-v2/service/apigateway v1.38.4 h1:V8gcFwJPP3eXZXpeui+p9 
github.com/aws/aws-sdk-go-v2/service/apigateway v1.38.4/go.mod h1:iJF5UdwkFue/YuUGCFsCCdT3SBMUx0s+h5TNi0Sz+qg= github.com/aws/aws-sdk-go-v2/service/autoscaling v1.64.0 h1:s92jPptCu97RNwU1yF3jD4ahLZrQ0QkUIvrn464rQ2A= github.com/aws/aws-sdk-go-v2/service/autoscaling v1.64.0/go.mod h1:8O5Pj92iNpfw/Fa7WdHbn6YiEjDoVdutz+9PGRNoP3Y= -github.com/aws/aws-sdk-go-v2/service/cloudfront v1.59.0 h1:evSZnlPGyDgStAmjLK9LcSoLvEk3oSUyJz4KIFfzJEs= -github.com/aws/aws-sdk-go-v2/service/cloudfront v1.59.0/go.mod h1:9Hd/cqshF4zl13KGLkWtRfITbvKR6m6FZHwhL2BYDSY= +github.com/aws/aws-sdk-go-v2/service/cloudfront v1.60.0 h1:RUQqU9L1LnFJ+9t5hsSB7GI6dVvJDCnG4WgRlDeHK6E= +github.com/aws/aws-sdk-go-v2/service/cloudfront v1.60.0/go.mod h1:9Hd/cqshF4zl13KGLkWtRfITbvKR6m6FZHwhL2BYDSY= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.53.1 h1:ElB5x0nrBHgQs+XcpQ1XJpSJzMFCq6fDTpT6WQCWOtQ= github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.53.1/go.mod h1:Cj+LUEvAU073qB2jInKV6Y0nvHX0k7bL7KAga9zZ3jw= github.com/aws/aws-sdk-go-v2/service/directconnect v1.38.11 h1:3+DkKJAq5VVqPNu3eT6j0UchZDjDsNeqFNAqsomMPDc= github.com/aws/aws-sdk-go-v2/service/directconnect v1.38.11/go.mod h1:DNG3VkdVy874VMHH46ekGsD3nq6D4tyDV3HIOuVoouM= -github.com/aws/aws-sdk-go-v2/service/dynamodb v1.54.0 h1:SW3MUVGaqOv/h4spv3IubyGz9CpvE0gHWEJsZQNPFMs= -github.com/aws/aws-sdk-go-v2/service/dynamodb v1.54.0/go.mod h1:ctEsEHY2vFQc6i4KU07q4n68v7BAmTbujv2Y+z8+hQY= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0 h1:CyYoeHWjVSGimzMhlL0Z4l5gLCa++ccnRJKrsaNssxE= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0/go.mod h1:ctEsEHY2vFQc6i4KU07q4n68v7BAmTbujv2Y+z8+hQY= github.com/aws/aws-sdk-go-v2/service/ec2 v1.285.0 h1:cRZQsqCy59DSJmvmUYzi9K+dutysXzfx6F+fkcIHtOk= github.com/aws/aws-sdk-go-v2/service/ec2 v1.285.0/go.mod h1:Uy+C+Sc58jozdoL1McQr8bDsEvNFx+/nBY+vpO1HVUY= github.com/aws/aws-sdk-go-v2/service/ecs v1.71.0 h1:MzP/ElwTpINq+hS80ZQz4epKVnUTlz8Sz+P/AFORCKM= github.com/aws/aws-sdk-go-v2/service/ecs v1.71.0/go.mod h1:pMlGFDpHoLTJOIZHGdJOAWmi+xeIlQXuFTuQxs1epYE= github.com/aws/aws-sdk-go-v2/service/efs v1.41.10 h1:7ixaaFyZ8xXJWPcK3qQKFf1k1HgME9rtCY7S6Unih8I= github.com/aws/aws-sdk-go-v2/service/efs v1.41.10/go.mod h1:QwCUd/L5/HX4s/uWt3LPEOwQb/AYE4OyMGB8SL9/W4Y= -github.com/aws/aws-sdk-go-v2/service/eks v1.77.0 h1:Z5mTpmbJKU7jEM7xoXI5tO4Nm0JUZSgVSFkpYuu6Ic0= -github.com/aws/aws-sdk-go-v2/service/eks v1.77.0/go.mod h1:Qg678m+87sCuJhcsZojenz8mblYG+Tq86V4m3hjVz0s= +github.com/aws/aws-sdk-go-v2/service/eks v1.77.1 h1:pMXNbXUX4Xd9fRmRdEe/vQ/5EFRy2M4jvW6geO5lhd8= +github.com/aws/aws-sdk-go-v2/service/eks v1.77.1/go.mod h1:Qg678m+87sCuJhcsZojenz8mblYG+Tq86V4m3hjVz0s= github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.33.19 h1:ybEda2mkkX2o8NadXZBtcO9tgmW9cTQgeVSjypNsAy0= github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.33.19/go.mod h1:RiMytGvN4azx4yLM0Kn3bX/XO9dLxj+eG72Smy+vNzI= github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.54.6 h1:fQR1aeZKaiPkNPya0JMy2nhsoqoSgIWc3/QTiTiL1K0= @@ -361,8 +361,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao= github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= -github.com/googleapis/gax-go/v2 v2.16.0 h1:iHbQmKLLZrexmb0OSsNGTeSTS0HO4YvFOG8g5E4Zd0Y= -github.com/googleapis/gax-go/v2 v2.16.0/go.mod 
h1:o1vfQjjNZn4+dPnRdl/4ZD7S9414Y4xA+a/6Icj6l14= +github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc= +github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY= github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQHCoQ= github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= @@ -371,8 +371,8 @@ github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e h1:XmA6L9IP github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e/go.mod h1:AFIo+02s+12CEg8Gzz9kzhCbmbq6JcKNrhHffCGA9z4= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 h1:NmZ1PKzSTQbuGHw9DGPFomqkkLWMC+vZCkfs+FHv1Vg= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3/go.mod h1:zQrxl1YP88HQlA6i9c63DSVPFklWpGX4OWAc9bFuaH4= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7 h1:X+2YciYSxvMQK0UZ7sg45ZVabVZBeBuvMkmuI2V3Fak= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.7/go.mod h1:lW34nIZuQ8UDPdkon5fmfp2l3+ZkQ2me/+oecHYLOII= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= @@ -383,8 +383,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/hashicorp/hcl/v2 v2.24.0 h1:2QJdZ454DSsYGoaE6QheQZjtKZSUs9Nh2izTWiwQxvE= github.com/hashicorp/hcl/v2 v2.24.0/go.mod h1:oGoO1FIQYfn/AgyOhlg9qLC6/nOJPX3qGbkZpYAcqfM= -github.com/hashicorp/terraform-config-inspect v0.0.0-20260120201749-785479628bd7 h1:3roJG2qA6gqvm3O89wCtlIRnw2el75cC6A9t1akIZ9I= -github.com/hashicorp/terraform-config-inspect v0.0.0-20260120201749-785479628bd7/go.mod h1:Gz/z9Hbn+4KSp8A2FBtNszfLSdT2Tn/uAKGuVqqWmDI= +github.com/hashicorp/terraform-config-inspect v0.0.0-20260204111900-477360eb0c77 h1:JyCyXTn0iSHO66Gy5D+4Q031oqRBSRrARILrc1NFu2U= +github.com/hashicorp/terraform-config-inspect v0.0.0-20260204111900-477360eb0c77/go.mod h1:Gz/z9Hbn+4KSp8A2FBtNszfLSdT2Tn/uAKGuVqqWmDI= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -461,8 +461,8 @@ github.com/nats-io/nats-server/v2 v2.12.4 h1:ZnT10v2LU2Xcoiy8ek9X6Se4YG8EuMfIfvA github.com/nats-io/nats-server/v2 v2.12.4/go.mod h1:5MCp/pqm5SEfsvVZ31ll1088ZTwEUdvRX1Hmh/mTTDg= github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U= github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g= -github.com/nats-io/nkeys v0.4.12 h1:nssm7JKOG9/x4J8II47VWCL1Ds29avyiQDRn0ckMvDc= -github.com/nats-io/nkeys v0.4.12/go.mod h1:MT59A1HYcjIcyQDJStTfaOY6vhy9XTUjOFo+SVsvpBg= +github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4= +github.com/nats-io/nkeys v0.4.15/go.mod h1:CpMchTXC9fxA5zrMo4KpySxNjiDVvr8ANOSZdiNfUrs= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod 
h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/onsi/ginkgo/v2 v2.28.1 h1:S4hj+HbZp40fNKuLUQOYLDgZLwNUVn19N3Atb98NCyI= @@ -593,28 +593,28 @@ go.opentelemetry.io/contrib/detectors/gcp v1.38.0 h1:ZoYbqX7OaA/TAikspPl3ozPI6iY go.opentelemetry.io/contrib/detectors/gcp v1.38.0/go.mod h1:SU+iU7nu5ud4oCb3LQOhIZ3nRLj6FNVrKgtflbaf2ts= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ= -go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= -go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0/go.mod h1:vnakAaFckOMiMtOIhFI2MNH4FYrZzXCYxmb1LlhoGz8= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0 h1:Ckwye2FpXkYgiHX7fyVrN1uA/UYd9ounqqTuSNAv0k4= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.39.0/go.mod h1:teIFJh5pW2y+AN7riv6IBPX2DuesS3HgP39mwOspKwU= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0 h1:8UPA4IbVZxpsD76ihGOQiFml99GPAEZLohDXvqHdi6U= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0/go.mod h1:MZ1T/+51uIVKlRzGw1Fo46KEWThjlCBZKl2LzY5nv4g= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0= +go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= +go.opentelemetry.io/otel v1.40.0/go.mod h1:IMb+uXZUKkMXdPddhwAHm6UfOwJyh4ct1ybIlV14J0g= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0 h1:QKdN8ly8zEMrByybbQgv8cWBcdAarwmIPZ6FThrWXJs= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.40.0/go.mod h1:bTdK1nhqF76qiPoCCdyFIV+N/sRHYXYCTQc+3VCi3MI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.40.0 h1:wVZXIWjQSeSmMoxF74LzAnpVQOAFDo3pPji9Y4SOFKc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.40.0/go.mod h1:khvBS2IggMFNwZK/6lEeHg/W57h/IX6J4URh57fuI40= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.40.0 h1:MzfofMZN8ulNqobCmCAVbqVL5syHw+eB2qPRkCMA/fQ= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.40.0/go.mod h1:E73G9UFtKRXrxhBsHtG00TB5WxX57lpsQzogDkqBTz8= go.opentelemetry.io/otel/log v0.11.0 h1:c24Hrlk5WJ8JWcwbQxdBqxZdOK7PcP/LFtOtwpDTe3Y= go.opentelemetry.io/otel/log v0.11.0/go.mod h1:U/sxQ83FPmT29trrifhQg+Zj2lo1/IPN1PF6RTFqdwc= -go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= -go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/metric v1.40.0 h1:rcZe317KPftE2rstWIBitCdVp89A2HqjkxR3c11+p9g= +go.opentelemetry.io/otel/metric v1.40.0/go.mod h1:ib/crwQH7N3r5kfiBZQbwrTge743UDc7DTFVZrrXnqc= go.opentelemetry.io/otel/schema v0.0.12 h1:X8NKrwH07Oe9SJruY/D1XmwHrb6D2+qrLs2POlZX7F4= go.opentelemetry.io/otel/schema v0.0.12/go.mod 
h1:+w+Q7DdGfykSNi+UU9GAQz5/rtYND6FkBJUWUXzZb0M= -go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= -go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= -go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= -go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= -go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= -go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.opentelemetry.io/otel/sdk v1.40.0 h1:KHW/jUzgo6wsPh9At46+h4upjtccTmuZCFAc9OJ71f8= +go.opentelemetry.io/otel/sdk v1.40.0/go.mod h1:Ph7EFdYvxq72Y8Li9q8KebuYUr2KoeyHx0DRMKrYBUE= +go.opentelemetry.io/otel/sdk/metric v1.40.0 h1:mtmdVqgQkeRxHgRv4qhyJduP3fYJRMX4AtAlbuWdCYw= +go.opentelemetry.io/otel/sdk/metric v1.40.0/go.mod h1:4Z2bGMf0KSK3uRjlczMOeMhKU2rhUqdWNoKcYrtcBPg= +go.opentelemetry.io/otel/trace v1.40.0 h1:WA4etStDttCSYuhwvEa8OP8I5EWu24lkOzp+ZYblVjw= +go.opentelemetry.io/otel/trace v1.40.0/go.mod h1:zeAhriXecNGP/s2SEG3+Y8X9ujcJOTqQ5RgdEJcawiA= go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= @@ -694,14 +694,14 @@ golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhS golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/api v0.264.0 h1:+Fo3DQXBK8gLdf8rFZ3uLu39JpOnhvzJrLMQSoSYZJM= -google.golang.org/api v0.264.0/go.mod h1:fAU1xtNNisHgOF5JooAs8rRaTkl2rT3uaoNGo9NS3R8= +google.golang.org/api v0.265.0 h1:FZvfUdI8nfmuNrE34aOWFPmLC+qRBEiNm3JdivTvAAU= +google.golang.org/api v0.265.0/go.mod h1:uAvfEl3SLUj/7n6k+lJutcswVojHPp2Sp08jWCu8hLY= google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM= google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM= -google.golang.org/genproto/googleapis/api v0.0.0-20260122232226-8e98ce8d340d h1:tUKoKfdZnSjTf5LW7xpG4c6SZ3Ozisn5eumcoTuMEN4= -google.golang.org/genproto/googleapis/api v0.0.0-20260122232226-8e98ce8d340d/go.mod h1:p3MLuOwURrGBRoEyFHBT3GjUwaCQVKeNqqWxlcISGdw= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 h1:H86B94AW+VfJWDqFeEbBPhEtHzJwJfTbgE2lZa54ZAQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 h1:Jr5R2J6F6qWyzINc+4AM8t5pfUz6beZpHp678GNrMbE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= 
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= @@ -743,7 +743,7 @@ sigs.k8s.io/kind v0.31.0 h1:UcT4nzm+YM7YEbqiAKECk+b6dsvc/HRZZu9U0FolL1g= sigs.k8s.io/kind v0.31.0/go.mod h1:FSqriGaoTPruiXWfRnUXNykF8r2t+fHtK0P0m1AbGF8= sigs.k8s.io/randfill v1.0.0 h1:JfjMILfT8A6RbawdsK2JXGBR5AQVfd+9TbzrlneTyrU= sigs.k8s.io/randfill v1.0.0/go.mod h1:XeLlZ/jmk4i1HRopwe7/aU3H5n1zNUcX6TM94b3QxOY= -sigs.k8s.io/structured-merge-diff/v6 v6.3.0 h1:jTijUJbW353oVOd9oTlifJqOGEkUw2jB/fXCbTiQEco= -sigs.k8s.io/structured-merge-diff/v6 v6.3.0/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE= +sigs.k8s.io/structured-merge-diff/v6 v6.3.2-0.20260122202528-d9cc6641c482 h1:2WOzJpHUBVrrkDjU4KBT8n5LDcj824eX0I5UKcgeRUs= +sigs.k8s.io/structured-merge-diff/v6 v6.3.2-0.20260122202528-d9cc6641c482/go.mod h1:M3W8sfWvn2HhQDIbGWj3S099YozAsymCo/wrT5ohRUE= sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= diff --git a/k8s-source/adapters/endpoints.go b/k8s-source/adapters/endpoints.go index 75ee19d5..556a9b74 100644 --- a/k8s-source/adapters/endpoints.go +++ b/k8s-source/adapters/endpoints.go @@ -100,6 +100,7 @@ func newEndpointsAdapter(cs *kubernetes.Clientset, cluster string, namespaces [] }, LinkedItemQueryExtractor: EndpointsExtractor, AdapterMetadata: endpointsAdapterMetadata, + cache: cache, } } diff --git a/k8s-source/adapters/generic_source.go b/k8s-source/adapters/generic_source.go index 0f1c0434..5ca362a8 100644 --- a/k8s-source/adapters/generic_source.go +++ b/k8s-source/adapters/generic_source.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "sync" "time" "github.com/overmindtech/cli/sdp-go" @@ -87,21 +86,6 @@ func (s *KubeTypeAdapter[Resource, ResourceList]) cacheDuration() time.Duration return s.CacheDuration } -var ( - noOpCacheK8sOnce sync.Once - noOpCacheK8s sdpcache.Cache -) - -func (s *KubeTypeAdapter[Resource, ResourceList]) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheK8sOnce.Do(func() { - noOpCacheK8s = sdpcache.NewNoOpCache() - }) - return noOpCacheK8s - } - return s.cache -} - // validate Validates that the adapter is correctly set up func (s *KubeTypeAdapter[Resource, ResourceList]) Validate() error { if s.NamespacedInterfaceBuilder == nil && s.ClusterInterfaceBuilder == nil { @@ -178,7 +162,7 @@ func (s *KubeTypeAdapter[Resource, ResourceList]) Get(ctx context.Context, scope var qErr *sdp.QueryError var done func() - cacheHit, ck, cachedItems, qErr, done = s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.Type(), query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done = s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.Type(), query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -197,7 +181,7 @@ func (s *KubeTypeAdapter[Resource, ResourceList]) Get(ctx context.Context, scope ErrorType: sdp.QueryError_NOSCOPE, ErrorString: err.Error(), } - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) return nil, err } @@ -211,16 +195,16 @@ func (s *KubeTypeAdapter[Resource, ResourceList]) Get(ctx context.Context, scope ErrorString: statusErr.ErrStatus.Message, } } - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) return nil, err } item, err := s.resourceToItem(resource) if err != nil { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) return nil, 
err } - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) return item, nil } @@ -231,7 +215,7 @@ func (s *KubeTypeAdapter[Resource, ResourceList]) List(ctx context.Context, scop var qErr *sdp.QueryError var done func() - cacheHit, ck, cachedItems, qErr, done = s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.Type(), "", ignoreCache) + cacheHit, ck, cachedItems, qErr, done = s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_LIST, scope, s.Type(), "", ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -242,12 +226,12 @@ func (s *KubeTypeAdapter[Resource, ResourceList]) List(ctx context.Context, scop items, err := s.listWithOptions(ctx, scope, metav1.ListOptions{}) if err != nil { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) return nil, err } for _, item := range items { - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) } return items, nil @@ -291,12 +275,12 @@ func (s *KubeTypeAdapter[Resource, ResourceList]) Search(ctx context.Context, sc items, err := s.listWithOptions(ctx, scope, opts) if err != nil { - s.Cache().StoreError(ctx, err, s.cacheDuration(), ck) + s.cache.StoreError(ctx, err, s.cacheDuration(), ck) return nil, err } for _, item := range items { - s.Cache().StoreItem(ctx, item, s.cacheDuration(), ck) + s.cache.StoreItem(ctx, item, s.cacheDuration(), ck) } return items, nil diff --git a/k8s-source/adapters/generic_source_test.go b/k8s-source/adapters/generic_source_test.go index c895760a..7fcf5eb8 100644 --- a/k8s-source/adapters/generic_source_test.go +++ b/k8s-source/adapters/generic_source_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -179,6 +180,7 @@ func createAdapter(namespaced bool) *KubeTypeAdapter[*v1.Pod, *v1.PodList] { TypeName: "Pod", ClusterName: "minikube", Namespaces: []string{"default", "app1"}, + cache: sdpcache.NewNoOpCache(), } } diff --git a/k8s-source/adapters/shared_test.go b/k8s-source/adapters/shared_test.go index e6d60443..7241473b 100644 --- a/k8s-source/adapters/shared_test.go +++ b/k8s-source/adapters/shared_test.go @@ -210,11 +210,10 @@ func TestMain(m *testing.M) { // log.Println("🎁 Creating resources in cluster for testing") // err = CurrentCluster.ApplyBaselineConfig() - - if err != nil { - log.Fatal(err) - os.Exit(1) - } + // if err != nil { + // log.Fatal(err) + // os.Exit(1) + // } log.Println("✅ Running tests") code := m.Run() diff --git a/k8s-source/cmd/root.go b/k8s-source/cmd/root.go index 1645c4ee..9a93677f 100644 --- a/k8s-source/cmd/root.go +++ b/k8s-source/cmd/root.go @@ -18,6 +18,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/k8s-source/adapters" + "github.com/overmindtech/cli/k8s-source/proc" "github.com/overmindtech/cli/logging" "github.com/overmindtech/cli/sdpcache" "github.com/overmindtech/cli/tracing" @@ -39,361 +40,318 @@ var cfgFile string // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ - Use: "k8s-source", - Short: "Kubernetes source", + Use: "k8s-source", + Short: "Kubernetes source", + SilenceUsage: true, Long: `Gathers details from existing kubernetes clusters `, - 
Run: func(cmd *cobra.Command, args []string) { - exitcode := run(cmd, args) - os.Exit(exitcode) - }, -} - -func run(_ *cobra.Command, _ []string) int { - kubeconfig := viper.GetString("kubeconfig") - // get engine config - engineConfig, err := discovery.EngineConfigFromViper("k8s", tracing.Version()) - if err != nil { - log.WithError(err).Fatal("Could not get engine config from viper") - } + RunE: func(cmd *cobra.Command, args []string) error { + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + defer tracing.LogRecoverToReturn(ctx, "k8s-source.root") - log.WithFields(log.Fields{ - "kubeconfig": kubeconfig, - }).Info("Got config") - - var clientSet *kubernetes.Clientset - var restConfig *rest.Config - - if kubeconfig == "" { - log.Info("Using in-cluster config") - - restConfig, err = rest.InClusterConfig() + // get engine config + engineConfig, err := discovery.EngineConfigFromViper("k8s", tracing.Version()) if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("Could not load in-cluster config") - - return 1 + log.WithError(err).Error("Could not get engine config from viper") + return fmt.Errorf("could not get engine config from viper: %w", err) } - } else { - // Load kubernetes config from a file - restConfig, err = clientcmd.BuildConfigFromFlags("", kubeconfig) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("Could not load kubernetes config") - return 1 + // Best-effort: derive cluster-specific NATS queue name before Start(). + // This loads the kubeconfig just to hash the rest config string for the + // queue name. If it fails (e.g. in-cluster config not yet available), + // we continue with the default queue name — the underlying error will + // surface again after Start() via SetInitError. + if restCfg, loadErr := loadRestConfig(viper.GetString("kubeconfig")); loadErr == nil { + configHash := fmt.Sprintf("%x", sha256.Sum256([]byte(restCfg.String()))) + engineConfig.NATSQueueName = fmt.Sprintf("k8s-source-%v", configHash) } - } - - restConfig.Wrap(func(rt http.RoundTripper) http.RoundTripper { return otelhttp.NewTransport(rt) }) - // Set up rate limiting - restConfig.RateLimiter = flowcontrol.NewTokenBucketRateLimiter( - float32(viper.GetFloat64("rate-limit-qps")), - viper.GetInt("rate-limit-burst"), - ) - // Create clientSet - clientSet, err = kubernetes.NewForConfig(restConfig) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("Could not create kubernetes client") - return 1 - } - // - // Discover info - // - // Now that we have a connection to the kubernetes cluster we need to go - // about generating some adapters. - var k8sURL *url.URL - - k8sURL, err = url.Parse(restConfig.Host) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Errorf("Could not parse kubernetes url: %v", restConfig.Host) - - return 1 - } - - // Calculate the SHA-1 hash of the config to use as the queue name. This - // means that adapters with the same config will be in the same queue. 
- // Note that the config object implements redaction in the String() - // method so we don't have to worry about leaking secrets - configHash := fmt.Sprintf("%x", sha256.Sum256([]byte(restConfig.String()))) - engineConfig.NATSQueueName = fmt.Sprintf("k8s-source-%v", configHash) - - // If there is no port then set one - if k8sURL.Port() == "" { - switch k8sURL.Scheme { - case "http": - k8sURL.Host = k8sURL.Host + ":80" - case "https": - k8sURL.Host = k8sURL.Host + ":443" + if engineConfig.HeartbeatOptions == nil { + engineConfig.HeartbeatOptions = &discovery.HeartbeatOptions{} } - } - err = engineConfig.CreateClients() - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Fatal("could not create auth clients") - } - - // Work out the cluster name - clusterName := viper.GetString("cluster-name") - if clusterName == "" { - clusterName = k8sURL.Host - } - if engineConfig.HeartbeatOptions == nil { - engineConfig.HeartbeatOptions = &discovery.HeartbeatOptions{} - } + e, err := discovery.NewEngine(engineConfig) + if err != nil { + sentry.CaptureException(err) + log.WithError(err).Error("Error initializing Engine") + return fmt.Errorf("error initializing engine: %w", err) + } - e, err := discovery.NewEngine(engineConfig) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("Error initializing Engine") + // ReadinessCheck verifies adapters are healthy by using a Node adapter + // Timeout is handled by SendHeartbeat, HTTP handlers rely on request context + e.SetReadinessCheck(func(ctx context.Context) error { + // Find a Node adapter to verify adapter health + adapters := e.AdaptersByType("Node") + if len(adapters) == 0 { + return fmt.Errorf("readiness check failed: no Node adapters available") + } + // Use first adapter and try to list from first scope + adapter := adapters[0] + scopes := adapter.Scopes() + if len(scopes) == 0 { + return fmt.Errorf("readiness check failed: no scopes available for Node adapter") + } + listableAdapter, ok := adapter.(discovery.ListableAdapter) + if !ok { + return fmt.Errorf("readiness check failed: Node adapter is not listable") + } + _, err := listableAdapter.List(ctx, scopes[0], true) + if err != nil { + return fmt.Errorf("readiness check (listing nodes) failed: %w", err) + } + return nil + }) - return 1 - } + // Serve health probes before initialization so they're available even on failure + e.ServeHealthProbes(viper.GetInt("health-check-port")) - // ReadinessCheck verifies adapters are healthy by using a Node adapter - // Timeout is handled by SendHeartbeat, HTTP handlers rely on request context - e.SetReadinessCheck(func(ctx context.Context) error { - // Find a Node adapter to verify adapter health - adapters := e.AdaptersByType("Node") - if len(adapters) == 0 { - return fmt.Errorf("readiness check failed: no Node adapters available") - } - // Use first adapter and try to list from first scope - adapter := adapters[0] - scopes := adapter.Scopes() - if len(scopes) == 0 { - return fmt.Errorf("readiness check failed: no scopes available for Node adapter") - } - listableAdapter, ok := adapter.(discovery.ListableAdapter) - if !ok { - return fmt.Errorf("readiness check failed: Node adapter is not listable") - } - _, err := listableAdapter.List(ctx, scopes[0], true) + // Start the engine (NATS connection) before config validation so heartbeats work + err = e.Start(ctx) if err != nil { - return fmt.Errorf("readiness check (listing nodes) failed: %w", err) + sentry.CaptureException(err) + log.WithError(err).Error("Could not start 
engine") + return fmt.Errorf("could not start engine: %w", err) } - return nil - }) - // Start HTTP server for health checks - healthCheckPort := viper.GetInt("health-check-port") - e.ServeHealthProbes(healthCheckPort) + // Config validation and K8s client setup (permanent errors — SetInitError, stay running) + var loadAdapters func(ctx context.Context) error + reload := make(chan watch.Event, 1024) - // Create channels for interrupts - quit := make(chan os.Signal, 1) - signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) - restart := make(chan watch.Event, 1024) + k8sCfg, clientSet, clusterName, cfgErr := createK8sClient() + if cfgErr != nil { + log.WithError(cfgErr).Error("K8s source config error - pod will stay running with error status") + e.SetInitError(cfgErr) + sentry.CaptureException(cfgErr) + } else { + log.WithFields(log.Fields{ + "kubeconfig": k8sCfg.Kubeconfig, + "cluster-name": clusterName, + }).Info("Got config") + + // loadAdapters is the single-attempt adapter init function that lists + // namespaces, creates adapters, and adds them to the engine. + loadAdapters = func(ctx context.Context) error { + log.Info("Listing namespaces") + list, err := clientSet.CoreV1().Namespaces().List(ctx, metav1.ListOptions{}) + if err != nil { + return fmt.Errorf("could not list namespaces: %w", err) + } - // Get the initial starting point - list, err := clientSet.CoreV1().Namespaces().List(context.Background(), metav1.ListOptions{}) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("could not list namespaces") + namespaces := make([]string, len(list.Items)) + for i := range list.Items { + namespaces[i] = list.Items[i].Name + } - return 1 - } + log.WithField("count", len(namespaces)).Info("Got namespaces") - // Watch namespaces from here - wi, err := clientSet.CoreV1().Namespaces().Watch(context.Background(), metav1.ListOptions{ - ResourceVersion: list.ResourceVersion, - }) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("could not start watching namespaces") + // Create a shared cache for all adapters in this source + sharedCache := sdpcache.NewCache(ctx) - return 1 - } + // Create the adapter list + adapterList := adapters.LoadAllAdapters(clientSet, clusterName, namespaces, sharedCache) - watchCtx, watchCancel := context.WithCancel(context.Background()) - defer watchCancel() + // Add adapters to the engine + return e.AddAdapters(adapterList...) + } - go func() { - defer tracing.LogRecoverToReturn(watchCtx, "Namespace watch") + // Use InitialiseAdapters for the initial load (retries with backoff) + e.InitialiseAdapters(ctx, loadAdapters) - attempts := 0 - sleep := 1 * time.Second + // Set up namespace watch for dynamic restarts + watchCtx, watchCancel := context.WithCancel(ctx) + defer watchCancel() - for { - select { - case event, ok := <-wi.ResultChan(): - if !ok { - // When the channel is closed then we need to restart the - // watch. This happens regularly on EKS. 
- log.Debug("Namespace watch channel closed, re-subscribing") - - wi, err = watchNamespaces(watchCtx, clientSet) - // Check for transient network errors - if err != nil { - var netErr *net.OpError - if errors.As(err, &netErr) { - // Mark a failure - attempts++ - - // If we have had less than 3 failures then retry - if attempts < 4 { - // The watch interface will be nil if we - // couldn't connect, so create a fake watcher - // that is closed so that we end up in this loop - // again - wi = watch.NewFake() - wi.Stop() - - jitter := time.Duration(rand.Int63n(int64(sleep))) //nolint:gosec // we don't need cryptographically secure randomness here - sleep = sleep + jitter/2 - - log.WithError(err).Errorf("Transient network error, retrying in %v seconds", sleep.String()) - time.Sleep(sleep) - continue - } - } + go func() { + defer tracing.LogRecoverToReturn(watchCtx, "Namespace watch setup") - sentry.CaptureException(err) - log.WithError(err).Error("could not list namespaces") + // Wait briefly for initial adapter loading to complete or make progress + // before starting the namespace watch + wi, err := watchNamespaces(watchCtx, clientSet) + if err != nil { + watchErr := fmt.Errorf("could not start namespace watch: %w", err) + log.WithError(watchErr).Error("K8s namespace watch failed - pod will stay running with error status") + e.SetInitError(watchErr) + sentry.CaptureException(watchErr) + return + } - // Send a fatal event that will kill the main goroutine - restart <- watch.Event{ - Type: watch.EventType("FATAL"), - } + defer tracing.LogRecoverToReturn(watchCtx, "Namespace watch") + + attempts := 0 + sleep := 1 * time.Second + + for { + select { + case event, ok := <-wi.ResultChan(): + if !ok { + // When the channel is closed then we need to restart the + // watch. This happens regularly on EKS. 
+ log.Debug("Namespace watch channel closed, re-subscribing") + + wi, err = watchNamespaces(watchCtx, clientSet) + // Check for transient network errors + if err != nil { + var netErr *net.OpError + if errors.As(err, &netErr) { + // Mark a failure + attempts++ + + // If we have had less than 3 failures then retry + if attempts < 4 { + // The watch interface will be nil if we + // couldn't connect, so create a fake watcher + // that is closed so that we end up in this loop + // again + wi = watch.NewFake() + wi.Stop() + + jitter := time.Duration(rand.Int63n(int64(sleep))) //nolint:gosec // we don't need cryptographically secure randomness here + sleep = sleep + jitter/2 + + log.WithError(err).WithField("retry_in", sleep.String()).Error("Transient network error, retrying") + time.Sleep(sleep) + continue + } + } + + sentry.CaptureException(err) + log.WithError(err).Error("could not resubscribe to namespace watch") + + // Send a fatal event + reload <- watch.Event{ + Type: watch.EventType("FATAL"), + } + + return + } + // If it's worked, reset the failure counter + attempts = 0 + } else { + // If a watch event is received then we need to reload adapters + reload <- event + } + case <-watchCtx.Done(): return } - - // If it's worked, reset the failure counter - attempts = 0 - } else { - // If a watch event is received then we need to restart the - // engine - restart <- event } - case <-watchCtx.Done(): - return - } - } - }() - - start := func() error { - ctx := context.Background() - - // Query all namespaces - log.Info("Listing namespaces") - list, err := clientSet.CoreV1().Namespaces().List(ctx, metav1.ListOptions{}) - if err != nil { - return err + }() } - namespaces := make([]string, len(list.Items)) - - for i := range list.Items { - namespaces[i] = list.Items[i].Name - } - - log.Infof("got %v namespaces", len(namespaces)) - - // Create a shared cache for all adapters in this source - sharedCache := sdpcache.NewCache(ctx) - - // Create the adapter list - adapterList := adapters.LoadAllAdapters(clientSet, clusterName, namespaces, sharedCache) + defer func() { + err := e.Stop() + if err != nil { + sentry.CaptureException(fmt.Errorf("could not stop engine: %w", err)) + log.WithError(err).Error("Could not stop engine") + } + }() - // Add adapters to the engine - err = e.AddAdapters(adapterList...) - if err != nil { - return err + for { + select { + case <-ctx.Done(): + log.Info("Stopping engine") + return nil + case event := <-reload: + switch event.Type { //nolint:exhaustive // we on purpose fall through to default + case "": + // Discard empty events. After a certain period kubernetes + // starts sending occasional empty events, I can't work out why, + // maybe it's to keep the connection open. 
Either way they don't + // represent anything and should be discarded + log.Debug("Discarding empty event") + case "FATAL": + // This is a custom event type from permanent watch failures + // Don't exit - store error and continue in degraded state + fatalErr := fmt.Errorf("permanent failure in namespace watch after retries") + log.WithError(fatalErr).Error("K8s namespace watch failed permanently - pod will stay running with error status") + e.SetInitError(fatalErr) + sentry.CaptureException(fatalErr) + case "MODIFIED": + log.Debug("Namespace modified, ignoring") + default: + // Namespace added/deleted: reload adapters + log.WithField("event_type", event.Type).Info("Namespace change detected, reloading adapters") + e.ClearAdapters() + if reloadErr := loadAdapters(ctx); reloadErr != nil { + initErr := fmt.Errorf("could not reload adapters after namespace change: %w", reloadErr) + log.WithError(initErr).Error("K8s source reload failed - pod will stay running with error status") + e.SetInitError(initErr) + sentry.CaptureException(initErr) + } else { + // Reload succeeded, clear any previous init error + e.SetInitError(nil) + log.Info("K8s source reloaded successfully") + } + } + } } + }, +} - // Start the engine - err = e.Start(ctx) - - return err +// Execute adds all child commands to the root command and sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) } +} - stop := func() error { - // Stop the engine - err = e.Stop() - if err != nil { - return err - } - - // Clear the adapters - e.ClearAdapters() - - return nil +// loadRestConfig loads a Kubernetes rest.Config from the given kubeconfig path. +// If the path is empty, in-cluster config is used. +func loadRestConfig(kubeconfig string) (*rest.Config, error) { + if kubeconfig == "" { + return rest.InClusterConfig() } + return clientcmd.BuildConfigFromFlags("", kubeconfig) +} - // Start the service initially - err = start() +// createK8sClient validates the K8s source config from viper, creates a +// Kubernetes client, and determines the cluster name. All failures are +// permanent config errors that should be reported via SetInitError. +func createK8sClient() (*proc.K8sConfig, *kubernetes.Clientset, string, error) { + k8sCfg, err := proc.ConfigFromViper() if err != nil { - err = fmt.Errorf("Could not start engine: %w", err) - sentry.CaptureException(err) - log.WithError(err) + return nil, nil, "", err + } - return 1 + restConfig, err := loadRestConfig(k8sCfg.Kubeconfig) + if err != nil { + return nil, nil, "", fmt.Errorf("could not load kubernetes config: %w", err) } - defer func() { - err := stop() - if err != nil { - err = fmt.Errorf("Could not stop engine: %w", err) - sentry.CaptureException(err) - log.WithError(err) - } - }() - - for { - select { - case <-quit: - log.Info("Stopping engine") - - // Stopping will be handled by deferred stop() - - return 0 - case event := <-restart: - switch event.Type { //nolint:exhaustive // we on purpose fall through to default - case "": - // Discard empty events. After a certain period kubernetes - // starts sending occasional empty events, I can't work out why, - // maybe it's to keep the connection open. 
Either way they don't - // represent anything and should be discarded - log.Debug("Discarding empty event") - case "FATAL": - // This is a custom event type that should signal the main - // goroutine to exit - log.Error("Fatal error in watch goroutine") - return 1 - case "MODIFIED": - log.Debug("Namespace modified, ignoring") - default: - err = stop() - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("Could not stop engine") + restConfig.Wrap(func(rt http.RoundTripper) http.RoundTripper { return otelhttp.NewTransport(rt) }) + restConfig.RateLimiter = flowcontrol.NewTokenBucketRateLimiter( + float32(k8sCfg.RateLimitQPS), + k8sCfg.RateLimitBurst, + ) - return 1 - } + clientSet, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return nil, nil, "", fmt.Errorf("could not create kubernetes client: %w", err) + } - err = start() - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Error("Could not start engine") + k8sURL, err := url.Parse(restConfig.Host) + if err != nil { + return nil, nil, "", fmt.Errorf("could not parse kubernetes url %v: %w", restConfig.Host, err) + } - return 1 - } - } + if k8sURL.Port() == "" { + switch k8sURL.Scheme { + case "http": + k8sURL.Host = k8sURL.Host + ":80" + case "https": + k8sURL.Host = k8sURL.Host + ":443" } } -} -// Execute adds all child commands to the root command and sets flags appropriately. -// This is called by main.main(). It only needs to happen once to the rootCmd. -func Execute() { - if err := rootCmd.Execute(); err != nil { - fmt.Println(err) - os.Exit(1) + clusterName := k8sCfg.ClusterName + if clusterName == "" { + clusterName = k8sURL.Host } + + return k8sCfg, clientSet, clusterName, nil } // Watches k8s namespaces from the current state, sending new events for each change @@ -423,7 +381,7 @@ func init() { // will be global for your application. var logLevel string - rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "/etc/srcman/config/k8s-source.yaml", "config file path") + rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "/etc/srcman/config/source.yaml", "config file path") rootCmd.PersistentFlags().StringVar(&logLevel, "log", "info", "Set the log level. 
Valid values: panic, fatal, error, warn, info, debug, trace") rootCmd.PersistentFlags().Int("health-check-port", 8080, "The port on which to serve health check endpoints (/healthz/alive, /healthz/ready, /healthz)") @@ -447,7 +405,7 @@ func init() { cobra.CheckErr(viper.BindPFlags(rootCmd.PersistentFlags())) // Run this before we do anything to set up the loglevel - rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { if lvl, err := log.ParseLevel(logLevel); err == nil { log.SetLevel(lvl) } else { @@ -460,23 +418,29 @@ func init() { ))) // Bind flags that haven't been set to the values from viper of we have them + var bindErr error cmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { // Bind the flag to viper only if it has a non-empty default if f.DefValue != "" || f.Changed { - err := viper.BindPFlag(f.Name, f) - if err != nil { - log.WithError(err).Errorf("Could not bind flag %s to viper", f.Name) + if err := viper.BindPFlag(f.Name, f); err != nil { + bindErr = err } } }) + if bindErr != nil { + log.WithError(bindErr).Error("could not bind flag to viper") + return fmt.Errorf("could not bind flag to viper: %w", bindErr) + } if viper.GetBool("json-log") { logging.ConfigureLogrusJSON(log.StandardLogger()) } if err := tracing.InitTracerWithUpstreams("k8s-source", viper.GetString("honeycomb-api-key"), viper.GetString("sentry-dsn")); err != nil { - log.Fatal(err) + log.WithError(err).Error("could not init tracer") + return fmt.Errorf("could not init tracer: %w", err) } + return nil } // shut down tracing at the end of the process @@ -508,6 +472,8 @@ func (t TerminationLogHook) Levels() []log.Level { } func (t TerminationLogHook) Fire(e *log.Entry) error { + // shutdown tracing first to ensure all spans are flushed + tracing.ShutdownTracer(context.Background()) tLog, err := os.OpenFile("/dev/termination-log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return err diff --git a/k8s-source/deployments/overmind-kube-source/templates/deployment.yaml b/k8s-source/deployments/overmind-kube-source/templates/deployment.yaml index 18228f46..c826a4e3 100644 --- a/k8s-source/deployments/overmind-kube-source/templates/deployment.yaml +++ b/k8s-source/deployments/overmind-kube-source/templates/deployment.yaml @@ -52,7 +52,7 @@ spec: value: "8080" livenessProbe: httpGet: - path: /healthz + path: /healthz/alive port: 8080 initialDelaySeconds: 30 periodSeconds: 10 @@ -61,7 +61,7 @@ spec: successThreshold: 1 readinessProbe: httpGet: - path: /healthz + path: /healthz/ready port: 8080 initialDelaySeconds: 5 periodSeconds: 10 diff --git a/k8s-source/proc/proc.go b/k8s-source/proc/proc.go new file mode 100644 index 00000000..455219f1 --- /dev/null +++ b/k8s-source/proc/proc.go @@ -0,0 +1,43 @@ +package proc + +import ( + "fmt" + + "github.com/spf13/viper" +) + +// K8sConfig holds configuration for the k8s source read from viper. +type K8sConfig struct { + Kubeconfig string + ClusterName string + RateLimitQPS float64 + RateLimitBurst int + HealthCheckPort int +} + +// ConfigFromViper reads and validates k8s source configuration from viper. +// Kubeconfig may be empty (in-cluster config). Returns an error if rate limits +// or health-check-port are invalid. 
+func ConfigFromViper() (*K8sConfig, error) {
+	rateLimitQPS := viper.GetFloat64("rate-limit-qps")
+	rateLimitBurst := viper.GetInt("rate-limit-burst")
+	healthCheckPort := viper.GetInt("health-check-port")
+
+	if rateLimitQPS <= 0 {
+		return nil, fmt.Errorf("rate-limit-qps must be positive, got %v", rateLimitQPS)
+	}
+	if rateLimitBurst <= 0 {
+		return nil, fmt.Errorf("rate-limit-burst must be positive, got %v", rateLimitBurst)
+	}
+	if healthCheckPort < 1 || healthCheckPort > 65535 {
+		return nil, fmt.Errorf("health-check-port must be between 1 and 65535, got %v", healthCheckPort)
+	}
+
+	return &K8sConfig{
+		Kubeconfig:      viper.GetString("kubeconfig"),
+		ClusterName:     viper.GetString("cluster-name"),
+		RateLimitQPS:    rateLimitQPS,
+		RateLimitBurst:  rateLimitBurst,
+		HealthCheckPort: healthCheckPort,
+	}, nil
+}
diff --git a/sdp-go/changes.go b/sdp-go/changes.go
index 04ab64cc..63881d87 100644
--- a/sdp-go/changes.go
+++ b/sdp-go/changes.go
@@ -391,31 +391,65 @@ func validateRoutineChangesConfig(routineChangesConfigYAML *RoutineChangesYAML)
 	return nil
 }
 
+// TimelineEntryContentDescription returns a human-readable description of the
+// entry's content based on its type.
+func TimelineEntryContentDescription(entry *ChangeTimelineEntryV2) string {
+	switch c := entry.GetContent().(type) {
+	case *ChangeTimelineEntryV2_MappedItems:
+		return fmt.Sprintf("%d mapped items", len(c.MappedItems.GetMappedItems()))
+	case *ChangeTimelineEntryV2_CalculatedBlastRadius:
+		return fmt.Sprintf("%d items, %d edges", c.CalculatedBlastRadius.GetNumItems(), c.CalculatedBlastRadius.GetNumEdges())
+	case *ChangeTimelineEntryV2_CalculatedRisks:
+		return fmt.Sprintf("%d risks", len(c.CalculatedRisks.GetRisks()))
+	case *ChangeTimelineEntryV2_CalculatedLabels:
+		return fmt.Sprintf("%d labels", len(c.CalculatedLabels.GetLabels()))
+	case *ChangeTimelineEntryV2_ChangeValidation:
+		return fmt.Sprintf("%d validation categories", len(c.ChangeValidation.GetValidationChecklist()))
+	case *ChangeTimelineEntryV2_FormHypotheses:
+		return fmt.Sprintf("%d hypotheses", c.FormHypotheses.GetNumHypotheses())
+	case *ChangeTimelineEntryV2_InvestigateHypotheses:
+		return fmt.Sprintf("%d proven, %d disproven, %d investigating",
+			c.InvestigateHypotheses.GetNumProven(),
+			c.InvestigateHypotheses.GetNumDisproven(),
+			c.InvestigateHypotheses.GetNumInvestigating())
+	case *ChangeTimelineEntryV2_RecordObservations:
+		return fmt.Sprintf("%d observations", c.RecordObservations.GetNumObservations())
+	case *ChangeTimelineEntryV2_Error:
+		return c.Error
+	case *ChangeTimelineEntryV2_StatusMessage:
+		return c.StatusMessage
+	case *ChangeTimelineEntryV2_Empty, nil:
+		return ""
+	default:
+		return ""
+	}
+}
+
 // TimelineFindInProgressEntry returns the current running entry in the list of entries
 // The function handles the following cases:
 // - If the input slice is nil or empty, it returns an error.
-// - The first entry that has a status of IN_PROGRESS, PENDING, or ERROR, it returns the entry's name, status, and a nil error.
+// - For the first entry that has a status of IN_PROGRESS, PENDING, or ERROR, it returns that entry's name, content description, status, and a nil error.
 // - If an entry has an unknown status, it returns an error.
-// - If the timeline is complete it returns an empty string, DONE status, and a nil error.
-func TimelineFindInProgressEntry(entries []*ChangeTimelineEntryV2) (string, ChangeTimelineEntryStatus, error) { +// - If the timeline is complete it returns an empty string, empty content description, DONE status, and a nil error. +func TimelineFindInProgressEntry(entries []*ChangeTimelineEntryV2) (string, string, ChangeTimelineEntryStatus, error) { if entries == nil { - return "", ChangeTimelineEntryStatus_UNSPECIFIED, errors.New("entries is nil") + return "", "", ChangeTimelineEntryStatus_UNSPECIFIED, errors.New("entries is nil") } if len(entries) == 0 { - return "", ChangeTimelineEntryStatus_UNSPECIFIED, errors.New("entries is empty") + return "", "", ChangeTimelineEntryStatus_UNSPECIFIED, errors.New("entries is empty") } for _, entry := range entries { switch entry.GetStatus() { case ChangeTimelineEntryStatus_IN_PROGRESS, ChangeTimelineEntryStatus_PENDING, ChangeTimelineEntryStatus_ERROR: // if the entry is in progress or about to start, or has an error(to be retried) - return entry.GetName(), entry.GetStatus(), nil + return entry.GetName(), TimelineEntryContentDescription(entry), entry.GetStatus(), nil case ChangeTimelineEntryStatus_UNSPECIFIED, ChangeTimelineEntryStatus_DONE: // do nothing default: - return "", ChangeTimelineEntryStatus_UNSPECIFIED, fmt.Errorf("unknown status: %s", entry.GetStatus().String()) + return "", "", ChangeTimelineEntryStatus_UNSPECIFIED, fmt.Errorf("unknown status: %s", entry.GetStatus().String()) } } - return "", ChangeTimelineEntryStatus_DONE, nil + return "", "", ChangeTimelineEntryStatus_DONE, nil } diff --git a/sdp-go/changes.pb.go b/sdp-go/changes.pb.go index 009b89ec..41245de2 100644 --- a/sdp-go/changes.pb.go +++ b/sdp-go/changes.pb.go @@ -26,10 +26,11 @@ const ( type MappedItemTimelineStatus int32 const ( - MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_UNSPECIFIED MappedItemTimelineStatus = 0 - MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_SUCCESS MappedItemTimelineStatus = 1 - MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_ERROR MappedItemTimelineStatus = 2 - MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_UNSUPPORTED MappedItemTimelineStatus = 3 + MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_UNSPECIFIED MappedItemTimelineStatus = 0 + MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_SUCCESS MappedItemTimelineStatus = 1 + MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_ERROR MappedItemTimelineStatus = 2 + MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_UNSUPPORTED MappedItemTimelineStatus = 3 + MappedItemTimelineStatus_MAPPED_ITEM_TIMELINE_STATUS_PENDING_CREATION MappedItemTimelineStatus = 4 ) // Enum value maps for MappedItemTimelineStatus. 
@@ -39,12 +40,14 @@ var ( 1: "MAPPED_ITEM_TIMELINE_STATUS_SUCCESS", 2: "MAPPED_ITEM_TIMELINE_STATUS_ERROR", 3: "MAPPED_ITEM_TIMELINE_STATUS_UNSUPPORTED", + 4: "MAPPED_ITEM_TIMELINE_STATUS_PENDING_CREATION", } MappedItemTimelineStatus_value = map[string]int32{ - "MAPPED_ITEM_TIMELINE_STATUS_UNSPECIFIED": 0, - "MAPPED_ITEM_TIMELINE_STATUS_SUCCESS": 1, - "MAPPED_ITEM_TIMELINE_STATUS_ERROR": 2, - "MAPPED_ITEM_TIMELINE_STATUS_UNSUPPORTED": 3, + "MAPPED_ITEM_TIMELINE_STATUS_UNSPECIFIED": 0, + "MAPPED_ITEM_TIMELINE_STATUS_SUCCESS": 1, + "MAPPED_ITEM_TIMELINE_STATUS_ERROR": 2, + "MAPPED_ITEM_TIMELINE_STATUS_UNSUPPORTED": 3, + "MAPPED_ITEM_TIMELINE_STATUS_PENDING_CREATION": 4, } ) @@ -75,6 +78,62 @@ func (MappedItemTimelineStatus) EnumDescriptor() ([]byte, []int) { return file_changes_proto_rawDescGZIP(), []int{0} } +// Explicit mapping status from CLI - allows CLI to communicate state instead of API inferring +type MappedItemMappingStatus int32 + +const ( + MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_UNSPECIFIED MappedItemMappingStatus = 0 + MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_SUCCESS MappedItemMappingStatus = 1 + MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_UNSUPPORTED MappedItemMappingStatus = 2 + MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_PENDING_CREATION MappedItemMappingStatus = 3 + MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_ERROR MappedItemMappingStatus = 4 +) + +// Enum value maps for MappedItemMappingStatus. +var ( + MappedItemMappingStatus_name = map[int32]string{ + 0: "MAPPED_ITEM_MAPPING_STATUS_UNSPECIFIED", + 1: "MAPPED_ITEM_MAPPING_STATUS_SUCCESS", + 2: "MAPPED_ITEM_MAPPING_STATUS_UNSUPPORTED", + 3: "MAPPED_ITEM_MAPPING_STATUS_PENDING_CREATION", + 4: "MAPPED_ITEM_MAPPING_STATUS_ERROR", + } + MappedItemMappingStatus_value = map[string]int32{ + "MAPPED_ITEM_MAPPING_STATUS_UNSPECIFIED": 0, + "MAPPED_ITEM_MAPPING_STATUS_SUCCESS": 1, + "MAPPED_ITEM_MAPPING_STATUS_UNSUPPORTED": 2, + "MAPPED_ITEM_MAPPING_STATUS_PENDING_CREATION": 3, + "MAPPED_ITEM_MAPPING_STATUS_ERROR": 4, + } +) + +func (x MappedItemMappingStatus) Enum() *MappedItemMappingStatus { + p := new(MappedItemMappingStatus) + *p = x + return p +} + +func (x MappedItemMappingStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (MappedItemMappingStatus) Descriptor() protoreflect.EnumDescriptor { + return file_changes_proto_enumTypes[1].Descriptor() +} + +func (MappedItemMappingStatus) Type() protoreflect.EnumType { + return &file_changes_proto_enumTypes[1] +} + +func (x MappedItemMappingStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use MappedItemMappingStatus.Descriptor instead. +func (MappedItemMappingStatus) EnumDescriptor() ([]byte, []int) { + return file_changes_proto_rawDescGZIP(), []int{1} +} + type HypothesisStatus int32 const ( @@ -120,11 +179,11 @@ func (x HypothesisStatus) String() string { } func (HypothesisStatus) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[1].Descriptor() + return file_changes_proto_enumTypes[2].Descriptor() } func (HypothesisStatus) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[1] + return &file_changes_proto_enumTypes[2] } func (x HypothesisStatus) Number() protoreflect.EnumNumber { @@ -133,7 +192,7 @@ func (x HypothesisStatus) Number() protoreflect.EnumNumber { // Deprecated: Use HypothesisStatus.Descriptor instead. 
func (HypothesisStatus) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{1} + return file_changes_proto_rawDescGZIP(), []int{2} } type ChangeTimelineEntryStatus int32 @@ -180,11 +239,11 @@ func (x ChangeTimelineEntryStatus) String() string { } func (ChangeTimelineEntryStatus) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[2].Descriptor() + return file_changes_proto_enumTypes[3].Descriptor() } func (ChangeTimelineEntryStatus) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[2] + return &file_changes_proto_enumTypes[3] } func (x ChangeTimelineEntryStatus) Number() protoreflect.EnumNumber { @@ -193,7 +252,7 @@ func (x ChangeTimelineEntryStatus) Number() protoreflect.EnumNumber { // Deprecated: Use ChangeTimelineEntryStatus.Descriptor instead. func (ChangeTimelineEntryStatus) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{2} + return file_changes_proto_rawDescGZIP(), []int{3} } type ItemDiffStatus int32 @@ -238,11 +297,11 @@ func (x ItemDiffStatus) String() string { } func (ItemDiffStatus) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[3].Descriptor() + return file_changes_proto_enumTypes[4].Descriptor() } func (ItemDiffStatus) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[3] + return &file_changes_proto_enumTypes[4] } func (x ItemDiffStatus) Number() protoreflect.EnumNumber { @@ -251,7 +310,7 @@ func (x ItemDiffStatus) Number() protoreflect.EnumNumber { // Deprecated: Use ItemDiffStatus.Descriptor instead. func (ItemDiffStatus) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{3} + return file_changes_proto_rawDescGZIP(), []int{4} } type ChangeOutputFormat int32 @@ -287,11 +346,11 @@ func (x ChangeOutputFormat) String() string { } func (ChangeOutputFormat) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[4].Descriptor() + return file_changes_proto_enumTypes[5].Descriptor() } func (ChangeOutputFormat) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[4] + return &file_changes_proto_enumTypes[5] } func (x ChangeOutputFormat) Number() protoreflect.EnumNumber { @@ -300,7 +359,7 @@ func (x ChangeOutputFormat) Number() protoreflect.EnumNumber { // Deprecated: Use ChangeOutputFormat.Descriptor instead. func (ChangeOutputFormat) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{4} + return file_changes_proto_rawDescGZIP(), []int{5} } type LabelType int32 @@ -336,11 +395,11 @@ func (x LabelType) String() string { } func (LabelType) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[5].Descriptor() + return file_changes_proto_enumTypes[6].Descriptor() } func (LabelType) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[5] + return &file_changes_proto_enumTypes[6] } func (x LabelType) Number() protoreflect.EnumNumber { @@ -349,7 +408,7 @@ func (x LabelType) Number() protoreflect.EnumNumber { // Deprecated: Use LabelType.Descriptor instead. 
func (LabelType) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{5} + return file_changes_proto_rawDescGZIP(), []int{6} } type ChangeStatus int32 @@ -403,11 +462,11 @@ func (x ChangeStatus) String() string { } func (ChangeStatus) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[6].Descriptor() + return file_changes_proto_enumTypes[7].Descriptor() } func (ChangeStatus) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[6] + return &file_changes_proto_enumTypes[7] } func (x ChangeStatus) Number() protoreflect.EnumNumber { @@ -416,7 +475,7 @@ func (x ChangeStatus) Number() protoreflect.EnumNumber { // Deprecated: Use ChangeStatus.Descriptor instead. func (ChangeStatus) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{6} + return file_changes_proto_rawDescGZIP(), []int{7} } type StartChangeResponse_State int32 @@ -459,11 +518,11 @@ func (x StartChangeResponse_State) String() string { } func (StartChangeResponse_State) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[7].Descriptor() + return file_changes_proto_enumTypes[8].Descriptor() } func (StartChangeResponse_State) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[7] + return &file_changes_proto_enumTypes[8] } func (x StartChangeResponse_State) Number() protoreflect.EnumNumber { @@ -515,11 +574,11 @@ func (x EndChangeResponse_State) String() string { } func (EndChangeResponse_State) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[8].Descriptor() + return file_changes_proto_enumTypes[9].Descriptor() } func (EndChangeResponse_State) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[8] + return &file_changes_proto_enumTypes[9] } func (x EndChangeResponse_State) Number() protoreflect.EnumNumber { @@ -567,11 +626,11 @@ func (x Risk_Severity) String() string { } func (Risk_Severity) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[9].Descriptor() + return file_changes_proto_enumTypes[10].Descriptor() } func (Risk_Severity) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[9] + return &file_changes_proto_enumTypes[10] } func (x Risk_Severity) Number() protoreflect.EnumNumber { @@ -580,7 +639,7 @@ func (x Risk_Severity) Number() protoreflect.EnumNumber { // Deprecated: Use Risk_Severity.Descriptor instead. func (Risk_Severity) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{87, 0} + return file_changes_proto_rawDescGZIP(), []int{89, 0} } type ChangeAnalysisStatus_Status int32 @@ -622,11 +681,11 @@ func (x ChangeAnalysisStatus_Status) String() string { } func (ChangeAnalysisStatus_Status) Descriptor() protoreflect.EnumDescriptor { - return file_changes_proto_enumTypes[10].Descriptor() + return file_changes_proto_enumTypes[11].Descriptor() } func (ChangeAnalysisStatus_Status) Type() protoreflect.EnumType { - return &file_changes_proto_enumTypes[10] + return &file_changes_proto_enumTypes[11] } func (x ChangeAnalysisStatus_Status) Number() protoreflect.EnumNumber { @@ -635,7 +694,7 @@ func (x ChangeAnalysisStatus_Status) Number() protoreflect.EnumNumber { // Deprecated: Use ChangeAnalysisStatus_Status.Descriptor instead. 
func (ChangeAnalysisStatus_Status) EnumDescriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{88, 0} + return file_changes_proto_rawDescGZIP(), []int{90, 0} } type LabelRule struct { @@ -2891,7 +2950,11 @@ type MappedItemDiff struct { MappingQuery *Query `protobuf:"bytes,2,opt,name=mappingQuery,proto3,oneof" json:"mappingQuery,omitempty"` // The error that was returned as part of the mapping process. This will be // empty if the mapping was successful. - MappingError *QueryError `protobuf:"bytes,3,opt,name=mappingError,proto3,oneof" json:"mappingError,omitempty"` + MappingError *QueryError `protobuf:"bytes,3,opt,name=mappingError,proto3,oneof" json:"mappingError,omitempty"` + // Explicit status from CLI - when set, API uses this instead of inferring. + // This allows CLI to distinguish between "unsupported resource type", + // "pending creation (doesn't exist yet)", and "actual mapping error". + MappingStatus *MappedItemMappingStatus `protobuf:"varint,4,opt,name=mapping_status,json=mappingStatus,proto3,enum=changes.MappedItemMappingStatus,oneof" json:"mapping_status,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -2947,6 +3010,13 @@ func (x *MappedItemDiff) GetMappingError() *QueryError { return nil } +func (x *MappedItemDiff) GetMappingStatus() MappedItemMappingStatus { + if x != nil && x.MappingStatus != nil { + return *x.MappingStatus + } + return MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_UNSPECIFIED +} + // StartChangeAnalysisRequest is used to start the change analysis process. This // will calculate various things blast radius, risks, auto-tagging etc. This // it contains overrides for the auto-tagging rules and the blast radius config @@ -5912,6 +5982,96 @@ func (x *EndChangeResponse) GetNumEdges() uint32 { return 0 } +type StartChangeSimpleResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StartChangeSimpleResponse) Reset() { + *x = StartChangeSimpleResponse{} + mi := &file_changes_proto_msgTypes[87] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StartChangeSimpleResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StartChangeSimpleResponse) ProtoMessage() {} + +func (x *StartChangeSimpleResponse) ProtoReflect() protoreflect.Message { + mi := &file_changes_proto_msgTypes[87] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StartChangeSimpleResponse.ProtoReflect.Descriptor instead. 
+func (*StartChangeSimpleResponse) Descriptor() ([]byte, []int) { + return file_changes_proto_rawDescGZIP(), []int{87} +} + +type EndChangeSimpleResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // True if the job was successfully enqueued (or queued to run after start-change) + Queued bool `protobuf:"varint,1,opt,name=queued,proto3" json:"queued,omitempty"` + // True if end-change was queued to run after start-change completes + QueuedAfterStart bool `protobuf:"varint,2,opt,name=queued_after_start,json=queuedAfterStart,proto3" json:"queued_after_start,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EndChangeSimpleResponse) Reset() { + *x = EndChangeSimpleResponse{} + mi := &file_changes_proto_msgTypes[88] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EndChangeSimpleResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EndChangeSimpleResponse) ProtoMessage() {} + +func (x *EndChangeSimpleResponse) ProtoReflect() protoreflect.Message { + mi := &file_changes_proto_msgTypes[88] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EndChangeSimpleResponse.ProtoReflect.Descriptor instead. +func (*EndChangeSimpleResponse) Descriptor() ([]byte, []int) { + return file_changes_proto_rawDescGZIP(), []int{88} +} + +func (x *EndChangeSimpleResponse) GetQueued() bool { + if x != nil { + return x.Queued + } + return false +} + +func (x *EndChangeSimpleResponse) GetQueuedAfterStart() bool { + if x != nil { + return x.QueuedAfterStart + } + return false +} + type Risk struct { state protoimpl.MessageState `protogen:"open.v1"` UUID []byte `protobuf:"bytes,5,opt,name=UUID,proto3" json:"UUID,omitempty"` @@ -5925,7 +6085,7 @@ type Risk struct { func (x *Risk) Reset() { *x = Risk{} - mi := &file_changes_proto_msgTypes[87] + mi := &file_changes_proto_msgTypes[89] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -5937,7 +6097,7 @@ func (x *Risk) String() string { func (*Risk) ProtoMessage() {} func (x *Risk) ProtoReflect() protoreflect.Message { - mi := &file_changes_proto_msgTypes[87] + mi := &file_changes_proto_msgTypes[89] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -5950,7 +6110,7 @@ func (x *Risk) ProtoReflect() protoreflect.Message { // Deprecated: Use Risk.ProtoReflect.Descriptor instead. 
func (*Risk) Descriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{87} + return file_changes_proto_rawDescGZIP(), []int{89} } func (x *Risk) GetUUID() []byte { @@ -5997,7 +6157,7 @@ type ChangeAnalysisStatus struct { func (x *ChangeAnalysisStatus) Reset() { *x = ChangeAnalysisStatus{} - mi := &file_changes_proto_msgTypes[88] + mi := &file_changes_proto_msgTypes[90] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -6009,7 +6169,7 @@ func (x *ChangeAnalysisStatus) String() string { func (*ChangeAnalysisStatus) ProtoMessage() {} func (x *ChangeAnalysisStatus) ProtoReflect() protoreflect.Message { - mi := &file_changes_proto_msgTypes[88] + mi := &file_changes_proto_msgTypes[90] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6022,7 +6182,7 @@ func (x *ChangeAnalysisStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use ChangeAnalysisStatus.ProtoReflect.Descriptor instead. func (*ChangeAnalysisStatus) Descriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{88} + return file_changes_proto_rawDescGZIP(), []int{90} } func (x *ChangeAnalysisStatus) GetStatus() ChangeAnalysisStatus_Status { @@ -6043,7 +6203,7 @@ type GenerateRiskFixRequest struct { func (x *GenerateRiskFixRequest) Reset() { *x = GenerateRiskFixRequest{} - mi := &file_changes_proto_msgTypes[89] + mi := &file_changes_proto_msgTypes[91] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -6055,7 +6215,7 @@ func (x *GenerateRiskFixRequest) String() string { func (*GenerateRiskFixRequest) ProtoMessage() {} func (x *GenerateRiskFixRequest) ProtoReflect() protoreflect.Message { - mi := &file_changes_proto_msgTypes[89] + mi := &file_changes_proto_msgTypes[91] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6068,7 +6228,7 @@ func (x *GenerateRiskFixRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GenerateRiskFixRequest.ProtoReflect.Descriptor instead. func (*GenerateRiskFixRequest) Descriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{89} + return file_changes_proto_rawDescGZIP(), []int{91} } func (x *GenerateRiskFixRequest) GetRiskUUID() []byte { @@ -6088,7 +6248,7 @@ type GenerateRiskFixResponse struct { func (x *GenerateRiskFixResponse) Reset() { *x = GenerateRiskFixResponse{} - mi := &file_changes_proto_msgTypes[90] + mi := &file_changes_proto_msgTypes[92] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -6100,7 +6260,7 @@ func (x *GenerateRiskFixResponse) String() string { func (*GenerateRiskFixResponse) ProtoMessage() {} func (x *GenerateRiskFixResponse) ProtoReflect() protoreflect.Message { - mi := &file_changes_proto_msgTypes[90] + mi := &file_changes_proto_msgTypes[92] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6113,7 +6273,7 @@ func (x *GenerateRiskFixResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GenerateRiskFixResponse.ProtoReflect.Descriptor instead. 
func (*GenerateRiskFixResponse) Descriptor() ([]byte, []int) { - return file_changes_proto_rawDescGZIP(), []int{90} + return file_changes_proto_rawDescGZIP(), []int{92} } func (x *GenerateRiskFixResponse) GetFixSuggestion() string { @@ -6143,7 +6303,7 @@ type ChangeMetadata_HealthChange struct { func (x *ChangeMetadata_HealthChange) Reset() { *x = ChangeMetadata_HealthChange{} - mi := &file_changes_proto_msgTypes[93] + mi := &file_changes_proto_msgTypes[95] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -6155,7 +6315,7 @@ func (x *ChangeMetadata_HealthChange) String() string { func (*ChangeMetadata_HealthChange) ProtoMessage() {} func (x *ChangeMetadata_HealthChange) ProtoReflect() protoreflect.Message { - mi := &file_changes_proto_msgTypes[93] + mi := &file_changes_proto_msgTypes[95] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -6352,13 +6512,15 @@ const file_changes_proto_rawDesc = "" + "changeUUID\x18\x01 \x01(\fR\n" + "changeUUID\"R\n" + " ListChangingItemsSummaryResponse\x12.\n" + - "\x05items\x18\x01 \x03(\v2\x18.changes.ItemDiffSummaryR\x05items\"\xc0\x01\n" + + "\x05items\x18\x01 \x03(\v2\x18.changes.ItemDiffSummaryR\x05items\"\xa1\x02\n" + "\x0eMappedItemDiff\x12%\n" + "\x04item\x18\x01 \x01(\v2\x11.changes.ItemDiffR\x04item\x12/\n" + "\fmappingQuery\x18\x02 \x01(\v2\x06.QueryH\x00R\fmappingQuery\x88\x01\x01\x124\n" + - "\fmappingError\x18\x03 \x01(\v2\v.QueryErrorH\x01R\fmappingError\x88\x01\x01B\x0f\n" + + "\fmappingError\x18\x03 \x01(\v2\v.QueryErrorH\x01R\fmappingError\x88\x01\x01\x12L\n" + + "\x0emapping_status\x18\x04 \x01(\x0e2 .changes.MappedItemMappingStatusH\x02R\rmappingStatus\x88\x01\x01B\x0f\n" + "\r_mappingQueryB\x0f\n" + - "\r_mappingError\"\xa1\x04\n" + + "\r_mappingErrorB\x11\n" + + "\x0f_mapping_status\"\xa1\x04\n" + "\x1aStartChangeAnalysisRequest\x12\x1e\n" + "\n" + "changeUUID\x18\x01 \x01(\fR\n" + @@ -6610,7 +6772,11 @@ const file_changes_proto_rawDesc = "" + "\x15STATE_TAKING_SNAPSHOT\x10\x01\x12\x19\n" + "\x15STATE_SAVING_SNAPSHOT\x10\x02\x12\x0e\n" + "\n" + - "STATE_DONE\x10\x03\"\x96\x02\n" + + "STATE_DONE\x10\x03\"\x1b\n" + + "\x19StartChangeSimpleResponse\"_\n" + + "\x17EndChangeSimpleResponse\x12\x16\n" + + "\x06queued\x18\x01 \x01(\bR\x06queued\x12,\n" + + "\x12queued_after_start\x18\x02 \x01(\bR\x10queuedAfterStart\"\x96\x02\n" + "\x04Risk\x12\x12\n" + "\x04UUID\x18\x05 \x01(\fR\x04UUID\x12\x14\n" + "\x05title\x18\x01 \x01(\tR\x05title\x122\n" + @@ -6634,12 +6800,19 @@ const file_changes_proto_rawDesc = "" + "\x16GenerateRiskFixRequest\x12\x1a\n" + "\briskUUID\x18\x01 \x01(\fR\briskUUID\"?\n" + "\x17GenerateRiskFixResponse\x12$\n" + - "\rfixSuggestion\x18\x01 \x01(\tR\rfixSuggestion*\xc4\x01\n" + + "\rfixSuggestion\x18\x01 \x01(\tR\rfixSuggestion*\xf6\x01\n" + "\x18MappedItemTimelineStatus\x12+\n" + "'MAPPED_ITEM_TIMELINE_STATUS_UNSPECIFIED\x10\x00\x12'\n" + "#MAPPED_ITEM_TIMELINE_STATUS_SUCCESS\x10\x01\x12%\n" + "!MAPPED_ITEM_TIMELINE_STATUS_ERROR\x10\x02\x12+\n" + - "'MAPPED_ITEM_TIMELINE_STATUS_UNSUPPORTED\x10\x03*\xf9\x01\n" + + "'MAPPED_ITEM_TIMELINE_STATUS_UNSUPPORTED\x10\x03\x120\n" + + ",MAPPED_ITEM_TIMELINE_STATUS_PENDING_CREATION\x10\x04*\xf0\x01\n" + + "\x17MappedItemMappingStatus\x12*\n" + + "&MAPPED_ITEM_MAPPING_STATUS_UNSPECIFIED\x10\x00\x12&\n" + + "\"MAPPED_ITEM_MAPPING_STATUS_SUCCESS\x10\x01\x12*\n" + + "&MAPPED_ITEM_MAPPING_STATUS_UNSUPPORTED\x10\x02\x12/\n" + + "+MAPPED_ITEM_MAPPING_STATUS_PENDING_CREATION\x10\x03\x12$\n" + + " 
MAPPED_ITEM_MAPPING_STATUS_ERROR\x10\x04*\xf9\x01\n" + "\x10HypothesisStatus\x12.\n" + "*INVESTIGATED_HYPOTHESIS_STATUS_UNSPECIFIED\x10\x00\x12*\n" + "&INVESTIGATED_HYPOTHESIS_STATUS_FORMING\x10\x01\x120\n" + @@ -6672,7 +6845,7 @@ const file_changes_proto_rawDesc = "" + "\x16CHANGE_STATUS_DEFINING\x10\x01\x12\x1b\n" + "\x17CHANGE_STATUS_HAPPENING\x10\x02\x12 \n" + "\x18CHANGE_STATUS_PROCESSING\x10\x03\x1a\x02\b\x01\x12\x16\n" + - "\x12CHANGE_STATUS_DONE\x10\x042\x87\x0f\n" + + "\x12CHANGE_STATUS_DONE\x10\x042\xad\x10\n" + "\x0eChangesService\x12H\n" + "\vListChanges\x12\x1b.changes.ListChangesRequest\x1a\x1c.changes.ListChangesResponse\x12`\n" + "\x13ListChangesByStatus\x12#.changes.ListChangesByStatusRequest\x1a$.changes.ListChangesByStatusResponse\x12K\n" + @@ -6688,6 +6861,8 @@ const file_changes_proto_rawDesc = "" + "\fRefreshState\x12\x1c.changes.RefreshStateRequest\x1a\x1d.changes.RefreshStateResponse\x12J\n" + "\vStartChange\x12\x1b.changes.StartChangeRequest\x1a\x1c.changes.StartChangeResponse0\x01\x12D\n" + "\tEndChange\x12\x19.changes.EndChangeRequest\x1a\x1a.changes.EndChangeResponse0\x01\x12T\n" + + "\x11StartChangeSimple\x12\x1b.changes.StartChangeRequest\x1a\".changes.StartChangeSimpleResponse\x12N\n" + + "\x0fEndChangeSimple\x12\x19.changes.EndChangeRequest\x1a .changes.EndChangeSimpleResponse\x12T\n" + "\x0fListHomeChanges\x12\x1f.changes.ListHomeChangesRequest\x1a .changes.ListHomeChangesResponse\x12`\n" + "\x13StartChangeAnalysis\x12#.changes.StartChangeAnalysisRequest\x1a$.changes.StartChangeAnalysisResponse\x12o\n" + "\x18ListChangingItemsSummary\x12(.changes.ListChangingItemsSummaryRequest\x1a).changes.ListChangingItemsSummaryResponse\x12<\n" + @@ -6717,303 +6892,311 @@ func file_changes_proto_rawDescGZIP() []byte { return file_changes_proto_rawDescData } -var file_changes_proto_enumTypes = make([]protoimpl.EnumInfo, 11) -var file_changes_proto_msgTypes = make([]protoimpl.MessageInfo, 95) +var file_changes_proto_enumTypes = make([]protoimpl.EnumInfo, 12) +var file_changes_proto_msgTypes = make([]protoimpl.MessageInfo, 97) var file_changes_proto_goTypes = []any{ (MappedItemTimelineStatus)(0), // 0: changes.MappedItemTimelineStatus - (HypothesisStatus)(0), // 1: changes.HypothesisStatus - (ChangeTimelineEntryStatus)(0), // 2: changes.ChangeTimelineEntryStatus - (ItemDiffStatus)(0), // 3: changes.ItemDiffStatus - (ChangeOutputFormat)(0), // 4: changes.ChangeOutputFormat - (LabelType)(0), // 5: changes.LabelType - (ChangeStatus)(0), // 6: changes.ChangeStatus - (StartChangeResponse_State)(0), // 7: changes.StartChangeResponse.State - (EndChangeResponse_State)(0), // 8: changes.EndChangeResponse.State - (Risk_Severity)(0), // 9: changes.Risk.Severity - (ChangeAnalysisStatus_Status)(0), // 10: changes.ChangeAnalysisStatus.Status - (*LabelRule)(nil), // 11: changes.LabelRule - (*LabelRuleMetadata)(nil), // 12: changes.LabelRuleMetadata - (*LabelRuleProperties)(nil), // 13: changes.LabelRuleProperties - (*ListLabelRulesRequest)(nil), // 14: changes.ListLabelRulesRequest - (*ListLabelRulesResponse)(nil), // 15: changes.ListLabelRulesResponse - (*CreateLabelRuleRequest)(nil), // 16: changes.CreateLabelRuleRequest - (*CreateLabelRuleResponse)(nil), // 17: changes.CreateLabelRuleResponse - (*GetLabelRuleRequest)(nil), // 18: changes.GetLabelRuleRequest - (*GetLabelRuleResponse)(nil), // 19: changes.GetLabelRuleResponse - (*UpdateLabelRuleRequest)(nil), // 20: changes.UpdateLabelRuleRequest - (*UpdateLabelRuleResponse)(nil), // 21: changes.UpdateLabelRuleResponse - 
(*DeleteLabelRuleRequest)(nil), // 22: changes.DeleteLabelRuleRequest - (*DeleteLabelRuleResponse)(nil), // 23: changes.DeleteLabelRuleResponse - (*TestLabelRuleRequest)(nil), // 24: changes.TestLabelRuleRequest - (*TestLabelRuleResponse)(nil), // 25: changes.TestLabelRuleResponse - (*ReapplyLabelRuleInTimeRangeRequest)(nil), // 26: changes.ReapplyLabelRuleInTimeRangeRequest - (*ReapplyLabelRuleInTimeRangeResponse)(nil), // 27: changes.ReapplyLabelRuleInTimeRangeResponse - (*GetHypothesesDetailsRequest)(nil), // 28: changes.GetHypothesesDetailsRequest - (*GetHypothesesDetailsResponse)(nil), // 29: changes.GetHypothesesDetailsResponse - (*HypothesesDetails)(nil), // 30: changes.HypothesesDetails - (*GetChangeTimelineV2Request)(nil), // 31: changes.GetChangeTimelineV2Request - (*GetChangeTimelineV2Response)(nil), // 32: changes.GetChangeTimelineV2Response - (*ChangeTimelineEntryV2)(nil), // 33: changes.ChangeTimelineEntryV2 - (*EmptyContent)(nil), // 34: changes.EmptyContent - (*MappedItemTimelineSummary)(nil), // 35: changes.MappedItemTimelineSummary - (*MappedItemsTimelineEntry)(nil), // 36: changes.MappedItemsTimelineEntry - (*CalculatedBlastRadiusTimelineEntry)(nil), // 37: changes.CalculatedBlastRadiusTimelineEntry - (*RecordObservationsTimelineEntry)(nil), // 38: changes.RecordObservationsTimelineEntry - (*FormHypothesesTimelineEntry)(nil), // 39: changes.FormHypothesesTimelineEntry - (*InvestigateHypothesesTimelineEntry)(nil), // 40: changes.InvestigateHypothesesTimelineEntry - (*HypothesisSummary)(nil), // 41: changes.HypothesisSummary - (*CalculatedRisksTimelineEntry)(nil), // 42: changes.CalculatedRisksTimelineEntry - (*CalculatedLabelsTimelineEntry)(nil), // 43: changes.CalculatedLabelsTimelineEntry - (*ChangeValidationTimelineEntry)(nil), // 44: changes.ChangeValidationTimelineEntry - (*ChangeValidationCategory)(nil), // 45: changes.ChangeValidationCategory - (*GetDiffRequest)(nil), // 46: changes.GetDiffRequest - (*GetDiffResponse)(nil), // 47: changes.GetDiffResponse - (*ListChangingItemsSummaryRequest)(nil), // 48: changes.ListChangingItemsSummaryRequest - (*ListChangingItemsSummaryResponse)(nil), // 49: changes.ListChangingItemsSummaryResponse - (*MappedItemDiff)(nil), // 50: changes.MappedItemDiff - (*StartChangeAnalysisRequest)(nil), // 51: changes.StartChangeAnalysisRequest - (*StartChangeAnalysisResponse)(nil), // 52: changes.StartChangeAnalysisResponse - (*ListHomeChangesRequest)(nil), // 53: changes.ListHomeChangesRequest - (*ChangeFiltersRequest)(nil), // 54: changes.ChangeFiltersRequest - (*ListHomeChangesResponse)(nil), // 55: changes.ListHomeChangesResponse - (*PopulateChangeFiltersRequest)(nil), // 56: changes.PopulateChangeFiltersRequest - (*PopulateChangeFiltersResponse)(nil), // 57: changes.PopulateChangeFiltersResponse - (*ItemDiffSummary)(nil), // 58: changes.ItemDiffSummary - (*ItemDiff)(nil), // 59: changes.ItemDiff - (*EnrichedTags)(nil), // 60: changes.EnrichedTags - (*TagValue)(nil), // 61: changes.TagValue - (*UserTagValue)(nil), // 62: changes.UserTagValue - (*AutoTagValue)(nil), // 63: changes.AutoTagValue - (*Label)(nil), // 64: changes.Label - (*ChangeSummary)(nil), // 65: changes.ChangeSummary - (*Change)(nil), // 66: changes.Change - (*ChangeMetadata)(nil), // 67: changes.ChangeMetadata - (*ChangeProperties)(nil), // 68: changes.ChangeProperties - (*GithubChangeInfo)(nil), // 69: changes.GithubChangeInfo - (*ListChangesRequest)(nil), // 70: changes.ListChangesRequest - (*ListChangesResponse)(nil), // 71: changes.ListChangesResponse - 
(*ListChangesByStatusRequest)(nil), // 72: changes.ListChangesByStatusRequest - (*ListChangesByStatusResponse)(nil), // 73: changes.ListChangesByStatusResponse - (*CreateChangeRequest)(nil), // 74: changes.CreateChangeRequest - (*CreateChangeResponse)(nil), // 75: changes.CreateChangeResponse - (*GetChangeRequest)(nil), // 76: changes.GetChangeRequest - (*GetChangeByTicketLinkRequest)(nil), // 77: changes.GetChangeByTicketLinkRequest - (*GetChangeSummaryRequest)(nil), // 78: changes.GetChangeSummaryRequest - (*GetChangeSummaryResponse)(nil), // 79: changes.GetChangeSummaryResponse - (*GetChangeSignalsRequest)(nil), // 80: changes.GetChangeSignalsRequest - (*GetChangeSignalsResponse)(nil), // 81: changes.GetChangeSignalsResponse - (*GetChangeResponse)(nil), // 82: changes.GetChangeResponse - (*GetChangeRisksRequest)(nil), // 83: changes.GetChangeRisksRequest - (*ChangeRiskMetadata)(nil), // 84: changes.ChangeRiskMetadata - (*GetChangeRisksResponse)(nil), // 85: changes.GetChangeRisksResponse - (*UpdateChangeRequest)(nil), // 86: changes.UpdateChangeRequest - (*UpdateChangeResponse)(nil), // 87: changes.UpdateChangeResponse - (*DeleteChangeRequest)(nil), // 88: changes.DeleteChangeRequest - (*ListChangesBySnapshotUUIDRequest)(nil), // 89: changes.ListChangesBySnapshotUUIDRequest - (*ListChangesBySnapshotUUIDResponse)(nil), // 90: changes.ListChangesBySnapshotUUIDResponse - (*DeleteChangeResponse)(nil), // 91: changes.DeleteChangeResponse - (*RefreshStateRequest)(nil), // 92: changes.RefreshStateRequest - (*RefreshStateResponse)(nil), // 93: changes.RefreshStateResponse - (*StartChangeRequest)(nil), // 94: changes.StartChangeRequest - (*StartChangeResponse)(nil), // 95: changes.StartChangeResponse - (*EndChangeRequest)(nil), // 96: changes.EndChangeRequest - (*EndChangeResponse)(nil), // 97: changes.EndChangeResponse - (*Risk)(nil), // 98: changes.Risk - (*ChangeAnalysisStatus)(nil), // 99: changes.ChangeAnalysisStatus - (*GenerateRiskFixRequest)(nil), // 100: changes.GenerateRiskFixRequest - (*GenerateRiskFixResponse)(nil), // 101: changes.GenerateRiskFixResponse - nil, // 102: changes.EnrichedTags.TagValueEntry - nil, // 103: changes.ChangeSummary.TagsEntry - (*ChangeMetadata_HealthChange)(nil), // 104: changes.ChangeMetadata.HealthChange - nil, // 105: changes.ChangeProperties.TagsEntry - (*timestamppb.Timestamp)(nil), // 106: google.protobuf.Timestamp - (*Edge)(nil), // 107: Edge - (*Query)(nil), // 108: Query - (*QueryError)(nil), // 109: QueryError - (*BlastRadiusConfig)(nil), // 110: config.BlastRadiusConfig - (*RoutineChangesConfig)(nil), // 111: config.RoutineChangesConfig - (*GithubOrganisationProfile)(nil), // 112: config.GithubOrganisationProfile - (*PaginationRequest)(nil), // 113: PaginationRequest - (SortOrder)(0), // 114: SortOrder - (*PaginationResponse)(nil), // 115: PaginationResponse - (*Reference)(nil), // 116: Reference - (Health)(0), // 117: Health - (*Item)(nil), // 118: Item + (MappedItemMappingStatus)(0), // 1: changes.MappedItemMappingStatus + (HypothesisStatus)(0), // 2: changes.HypothesisStatus + (ChangeTimelineEntryStatus)(0), // 3: changes.ChangeTimelineEntryStatus + (ItemDiffStatus)(0), // 4: changes.ItemDiffStatus + (ChangeOutputFormat)(0), // 5: changes.ChangeOutputFormat + (LabelType)(0), // 6: changes.LabelType + (ChangeStatus)(0), // 7: changes.ChangeStatus + (StartChangeResponse_State)(0), // 8: changes.StartChangeResponse.State + (EndChangeResponse_State)(0), // 9: changes.EndChangeResponse.State + (Risk_Severity)(0), // 10: changes.Risk.Severity + 
(ChangeAnalysisStatus_Status)(0), // 11: changes.ChangeAnalysisStatus.Status + (*LabelRule)(nil), // 12: changes.LabelRule + (*LabelRuleMetadata)(nil), // 13: changes.LabelRuleMetadata + (*LabelRuleProperties)(nil), // 14: changes.LabelRuleProperties + (*ListLabelRulesRequest)(nil), // 15: changes.ListLabelRulesRequest + (*ListLabelRulesResponse)(nil), // 16: changes.ListLabelRulesResponse + (*CreateLabelRuleRequest)(nil), // 17: changes.CreateLabelRuleRequest + (*CreateLabelRuleResponse)(nil), // 18: changes.CreateLabelRuleResponse + (*GetLabelRuleRequest)(nil), // 19: changes.GetLabelRuleRequest + (*GetLabelRuleResponse)(nil), // 20: changes.GetLabelRuleResponse + (*UpdateLabelRuleRequest)(nil), // 21: changes.UpdateLabelRuleRequest + (*UpdateLabelRuleResponse)(nil), // 22: changes.UpdateLabelRuleResponse + (*DeleteLabelRuleRequest)(nil), // 23: changes.DeleteLabelRuleRequest + (*DeleteLabelRuleResponse)(nil), // 24: changes.DeleteLabelRuleResponse + (*TestLabelRuleRequest)(nil), // 25: changes.TestLabelRuleRequest + (*TestLabelRuleResponse)(nil), // 26: changes.TestLabelRuleResponse + (*ReapplyLabelRuleInTimeRangeRequest)(nil), // 27: changes.ReapplyLabelRuleInTimeRangeRequest + (*ReapplyLabelRuleInTimeRangeResponse)(nil), // 28: changes.ReapplyLabelRuleInTimeRangeResponse + (*GetHypothesesDetailsRequest)(nil), // 29: changes.GetHypothesesDetailsRequest + (*GetHypothesesDetailsResponse)(nil), // 30: changes.GetHypothesesDetailsResponse + (*HypothesesDetails)(nil), // 31: changes.HypothesesDetails + (*GetChangeTimelineV2Request)(nil), // 32: changes.GetChangeTimelineV2Request + (*GetChangeTimelineV2Response)(nil), // 33: changes.GetChangeTimelineV2Response + (*ChangeTimelineEntryV2)(nil), // 34: changes.ChangeTimelineEntryV2 + (*EmptyContent)(nil), // 35: changes.EmptyContent + (*MappedItemTimelineSummary)(nil), // 36: changes.MappedItemTimelineSummary + (*MappedItemsTimelineEntry)(nil), // 37: changes.MappedItemsTimelineEntry + (*CalculatedBlastRadiusTimelineEntry)(nil), // 38: changes.CalculatedBlastRadiusTimelineEntry + (*RecordObservationsTimelineEntry)(nil), // 39: changes.RecordObservationsTimelineEntry + (*FormHypothesesTimelineEntry)(nil), // 40: changes.FormHypothesesTimelineEntry + (*InvestigateHypothesesTimelineEntry)(nil), // 41: changes.InvestigateHypothesesTimelineEntry + (*HypothesisSummary)(nil), // 42: changes.HypothesisSummary + (*CalculatedRisksTimelineEntry)(nil), // 43: changes.CalculatedRisksTimelineEntry + (*CalculatedLabelsTimelineEntry)(nil), // 44: changes.CalculatedLabelsTimelineEntry + (*ChangeValidationTimelineEntry)(nil), // 45: changes.ChangeValidationTimelineEntry + (*ChangeValidationCategory)(nil), // 46: changes.ChangeValidationCategory + (*GetDiffRequest)(nil), // 47: changes.GetDiffRequest + (*GetDiffResponse)(nil), // 48: changes.GetDiffResponse + (*ListChangingItemsSummaryRequest)(nil), // 49: changes.ListChangingItemsSummaryRequest + (*ListChangingItemsSummaryResponse)(nil), // 50: changes.ListChangingItemsSummaryResponse + (*MappedItemDiff)(nil), // 51: changes.MappedItemDiff + (*StartChangeAnalysisRequest)(nil), // 52: changes.StartChangeAnalysisRequest + (*StartChangeAnalysisResponse)(nil), // 53: changes.StartChangeAnalysisResponse + (*ListHomeChangesRequest)(nil), // 54: changes.ListHomeChangesRequest + (*ChangeFiltersRequest)(nil), // 55: changes.ChangeFiltersRequest + (*ListHomeChangesResponse)(nil), // 56: changes.ListHomeChangesResponse + (*PopulateChangeFiltersRequest)(nil), // 57: changes.PopulateChangeFiltersRequest + 
(*PopulateChangeFiltersResponse)(nil), // 58: changes.PopulateChangeFiltersResponse + (*ItemDiffSummary)(nil), // 59: changes.ItemDiffSummary + (*ItemDiff)(nil), // 60: changes.ItemDiff + (*EnrichedTags)(nil), // 61: changes.EnrichedTags + (*TagValue)(nil), // 62: changes.TagValue + (*UserTagValue)(nil), // 63: changes.UserTagValue + (*AutoTagValue)(nil), // 64: changes.AutoTagValue + (*Label)(nil), // 65: changes.Label + (*ChangeSummary)(nil), // 66: changes.ChangeSummary + (*Change)(nil), // 67: changes.Change + (*ChangeMetadata)(nil), // 68: changes.ChangeMetadata + (*ChangeProperties)(nil), // 69: changes.ChangeProperties + (*GithubChangeInfo)(nil), // 70: changes.GithubChangeInfo + (*ListChangesRequest)(nil), // 71: changes.ListChangesRequest + (*ListChangesResponse)(nil), // 72: changes.ListChangesResponse + (*ListChangesByStatusRequest)(nil), // 73: changes.ListChangesByStatusRequest + (*ListChangesByStatusResponse)(nil), // 74: changes.ListChangesByStatusResponse + (*CreateChangeRequest)(nil), // 75: changes.CreateChangeRequest + (*CreateChangeResponse)(nil), // 76: changes.CreateChangeResponse + (*GetChangeRequest)(nil), // 77: changes.GetChangeRequest + (*GetChangeByTicketLinkRequest)(nil), // 78: changes.GetChangeByTicketLinkRequest + (*GetChangeSummaryRequest)(nil), // 79: changes.GetChangeSummaryRequest + (*GetChangeSummaryResponse)(nil), // 80: changes.GetChangeSummaryResponse + (*GetChangeSignalsRequest)(nil), // 81: changes.GetChangeSignalsRequest + (*GetChangeSignalsResponse)(nil), // 82: changes.GetChangeSignalsResponse + (*GetChangeResponse)(nil), // 83: changes.GetChangeResponse + (*GetChangeRisksRequest)(nil), // 84: changes.GetChangeRisksRequest + (*ChangeRiskMetadata)(nil), // 85: changes.ChangeRiskMetadata + (*GetChangeRisksResponse)(nil), // 86: changes.GetChangeRisksResponse + (*UpdateChangeRequest)(nil), // 87: changes.UpdateChangeRequest + (*UpdateChangeResponse)(nil), // 88: changes.UpdateChangeResponse + (*DeleteChangeRequest)(nil), // 89: changes.DeleteChangeRequest + (*ListChangesBySnapshotUUIDRequest)(nil), // 90: changes.ListChangesBySnapshotUUIDRequest + (*ListChangesBySnapshotUUIDResponse)(nil), // 91: changes.ListChangesBySnapshotUUIDResponse + (*DeleteChangeResponse)(nil), // 92: changes.DeleteChangeResponse + (*RefreshStateRequest)(nil), // 93: changes.RefreshStateRequest + (*RefreshStateResponse)(nil), // 94: changes.RefreshStateResponse + (*StartChangeRequest)(nil), // 95: changes.StartChangeRequest + (*StartChangeResponse)(nil), // 96: changes.StartChangeResponse + (*EndChangeRequest)(nil), // 97: changes.EndChangeRequest + (*EndChangeResponse)(nil), // 98: changes.EndChangeResponse + (*StartChangeSimpleResponse)(nil), // 99: changes.StartChangeSimpleResponse + (*EndChangeSimpleResponse)(nil), // 100: changes.EndChangeSimpleResponse + (*Risk)(nil), // 101: changes.Risk + (*ChangeAnalysisStatus)(nil), // 102: changes.ChangeAnalysisStatus + (*GenerateRiskFixRequest)(nil), // 103: changes.GenerateRiskFixRequest + (*GenerateRiskFixResponse)(nil), // 104: changes.GenerateRiskFixResponse + nil, // 105: changes.EnrichedTags.TagValueEntry + nil, // 106: changes.ChangeSummary.TagsEntry + (*ChangeMetadata_HealthChange)(nil), // 107: changes.ChangeMetadata.HealthChange + nil, // 108: changes.ChangeProperties.TagsEntry + (*timestamppb.Timestamp)(nil), // 109: google.protobuf.Timestamp + (*Edge)(nil), // 110: Edge + (*Query)(nil), // 111: Query + (*QueryError)(nil), // 112: QueryError + (*BlastRadiusConfig)(nil), // 113: config.BlastRadiusConfig + 
(*RoutineChangesConfig)(nil), // 114: config.RoutineChangesConfig + (*GithubOrganisationProfile)(nil), // 115: config.GithubOrganisationProfile + (*PaginationRequest)(nil), // 116: PaginationRequest + (SortOrder)(0), // 117: SortOrder + (*PaginationResponse)(nil), // 118: PaginationResponse + (*Reference)(nil), // 119: Reference + (Health)(0), // 120: Health + (*Item)(nil), // 121: Item } var file_changes_proto_depIdxs = []int32{ - 12, // 0: changes.LabelRule.metadata:type_name -> changes.LabelRuleMetadata - 13, // 1: changes.LabelRule.properties:type_name -> changes.LabelRuleProperties - 106, // 2: changes.LabelRuleMetadata.createdAt:type_name -> google.protobuf.Timestamp - 106, // 3: changes.LabelRuleMetadata.updatedAt:type_name -> google.protobuf.Timestamp - 11, // 4: changes.ListLabelRulesResponse.rules:type_name -> changes.LabelRule - 13, // 5: changes.CreateLabelRuleRequest.properties:type_name -> changes.LabelRuleProperties - 11, // 6: changes.CreateLabelRuleResponse.rule:type_name -> changes.LabelRule - 11, // 7: changes.GetLabelRuleResponse.rule:type_name -> changes.LabelRule - 13, // 8: changes.UpdateLabelRuleRequest.properties:type_name -> changes.LabelRuleProperties - 11, // 9: changes.UpdateLabelRuleResponse.rule:type_name -> changes.LabelRule - 13, // 10: changes.TestLabelRuleRequest.properties:type_name -> changes.LabelRuleProperties - 64, // 11: changes.TestLabelRuleResponse.label:type_name -> changes.Label - 106, // 12: changes.ReapplyLabelRuleInTimeRangeRequest.startAt:type_name -> google.protobuf.Timestamp - 106, // 13: changes.ReapplyLabelRuleInTimeRangeRequest.endAt:type_name -> google.protobuf.Timestamp - 30, // 14: changes.GetHypothesesDetailsResponse.hypotheses:type_name -> changes.HypothesesDetails - 1, // 15: changes.HypothesesDetails.status:type_name -> changes.HypothesisStatus - 33, // 16: changes.GetChangeTimelineV2Response.entries:type_name -> changes.ChangeTimelineEntryV2 - 2, // 17: changes.ChangeTimelineEntryV2.status:type_name -> changes.ChangeTimelineEntryStatus - 106, // 18: changes.ChangeTimelineEntryV2.startedAt:type_name -> google.protobuf.Timestamp - 106, // 19: changes.ChangeTimelineEntryV2.endedAt:type_name -> google.protobuf.Timestamp - 36, // 20: changes.ChangeTimelineEntryV2.mappedItems:type_name -> changes.MappedItemsTimelineEntry - 37, // 21: changes.ChangeTimelineEntryV2.calculatedBlastRadius:type_name -> changes.CalculatedBlastRadiusTimelineEntry - 42, // 22: changes.ChangeTimelineEntryV2.calculatedRisks:type_name -> changes.CalculatedRisksTimelineEntry - 34, // 23: changes.ChangeTimelineEntryV2.empty:type_name -> changes.EmptyContent - 44, // 24: changes.ChangeTimelineEntryV2.changeValidation:type_name -> changes.ChangeValidationTimelineEntry - 43, // 25: changes.ChangeTimelineEntryV2.calculatedLabels:type_name -> changes.CalculatedLabelsTimelineEntry - 39, // 26: changes.ChangeTimelineEntryV2.formHypotheses:type_name -> changes.FormHypothesesTimelineEntry - 40, // 27: changes.ChangeTimelineEntryV2.investigateHypotheses:type_name -> changes.InvestigateHypothesesTimelineEntry - 38, // 28: changes.ChangeTimelineEntryV2.recordObservations:type_name -> changes.RecordObservationsTimelineEntry + 13, // 0: changes.LabelRule.metadata:type_name -> changes.LabelRuleMetadata + 14, // 1: changes.LabelRule.properties:type_name -> changes.LabelRuleProperties + 109, // 2: changes.LabelRuleMetadata.createdAt:type_name -> google.protobuf.Timestamp + 109, // 3: changes.LabelRuleMetadata.updatedAt:type_name -> google.protobuf.Timestamp + 12, // 4: 
changes.ListLabelRulesResponse.rules:type_name -> changes.LabelRule + 14, // 5: changes.CreateLabelRuleRequest.properties:type_name -> changes.LabelRuleProperties + 12, // 6: changes.CreateLabelRuleResponse.rule:type_name -> changes.LabelRule + 12, // 7: changes.GetLabelRuleResponse.rule:type_name -> changes.LabelRule + 14, // 8: changes.UpdateLabelRuleRequest.properties:type_name -> changes.LabelRuleProperties + 12, // 9: changes.UpdateLabelRuleResponse.rule:type_name -> changes.LabelRule + 14, // 10: changes.TestLabelRuleRequest.properties:type_name -> changes.LabelRuleProperties + 65, // 11: changes.TestLabelRuleResponse.label:type_name -> changes.Label + 109, // 12: changes.ReapplyLabelRuleInTimeRangeRequest.startAt:type_name -> google.protobuf.Timestamp + 109, // 13: changes.ReapplyLabelRuleInTimeRangeRequest.endAt:type_name -> google.protobuf.Timestamp + 31, // 14: changes.GetHypothesesDetailsResponse.hypotheses:type_name -> changes.HypothesesDetails + 2, // 15: changes.HypothesesDetails.status:type_name -> changes.HypothesisStatus + 34, // 16: changes.GetChangeTimelineV2Response.entries:type_name -> changes.ChangeTimelineEntryV2 + 3, // 17: changes.ChangeTimelineEntryV2.status:type_name -> changes.ChangeTimelineEntryStatus + 109, // 18: changes.ChangeTimelineEntryV2.startedAt:type_name -> google.protobuf.Timestamp + 109, // 19: changes.ChangeTimelineEntryV2.endedAt:type_name -> google.protobuf.Timestamp + 37, // 20: changes.ChangeTimelineEntryV2.mappedItems:type_name -> changes.MappedItemsTimelineEntry + 38, // 21: changes.ChangeTimelineEntryV2.calculatedBlastRadius:type_name -> changes.CalculatedBlastRadiusTimelineEntry + 43, // 22: changes.ChangeTimelineEntryV2.calculatedRisks:type_name -> changes.CalculatedRisksTimelineEntry + 35, // 23: changes.ChangeTimelineEntryV2.empty:type_name -> changes.EmptyContent + 45, // 24: changes.ChangeTimelineEntryV2.changeValidation:type_name -> changes.ChangeValidationTimelineEntry + 44, // 25: changes.ChangeTimelineEntryV2.calculatedLabels:type_name -> changes.CalculatedLabelsTimelineEntry + 40, // 26: changes.ChangeTimelineEntryV2.formHypotheses:type_name -> changes.FormHypothesesTimelineEntry + 41, // 27: changes.ChangeTimelineEntryV2.investigateHypotheses:type_name -> changes.InvestigateHypothesesTimelineEntry + 39, // 28: changes.ChangeTimelineEntryV2.recordObservations:type_name -> changes.RecordObservationsTimelineEntry 0, // 29: changes.MappedItemTimelineSummary.status:type_name -> changes.MappedItemTimelineStatus - 50, // 30: changes.MappedItemsTimelineEntry.mappedItems:type_name -> changes.MappedItemDiff - 35, // 31: changes.MappedItemsTimelineEntry.items:type_name -> changes.MappedItemTimelineSummary - 41, // 32: changes.FormHypothesesTimelineEntry.hypotheses:type_name -> changes.HypothesisSummary - 41, // 33: changes.InvestigateHypothesesTimelineEntry.hypotheses:type_name -> changes.HypothesisSummary - 1, // 34: changes.HypothesisSummary.status:type_name -> changes.HypothesisStatus - 98, // 35: changes.CalculatedRisksTimelineEntry.risks:type_name -> changes.Risk - 64, // 36: changes.CalculatedLabelsTimelineEntry.labels:type_name -> changes.Label - 45, // 37: changes.ChangeValidationTimelineEntry.validationChecklist:type_name -> changes.ChangeValidationCategory - 59, // 38: changes.GetDiffResponse.expectedItems:type_name -> changes.ItemDiff - 59, // 39: changes.GetDiffResponse.unexpectedItems:type_name -> changes.ItemDiff - 107, // 40: changes.GetDiffResponse.edges:type_name -> Edge - 59, // 41: 
changes.GetDiffResponse.missingItems:type_name -> changes.ItemDiff - 58, // 42: changes.ListChangingItemsSummaryResponse.items:type_name -> changes.ItemDiffSummary - 59, // 43: changes.MappedItemDiff.item:type_name -> changes.ItemDiff - 108, // 44: changes.MappedItemDiff.mappingQuery:type_name -> Query - 109, // 45: changes.MappedItemDiff.mappingError:type_name -> QueryError - 50, // 46: changes.StartChangeAnalysisRequest.changingItems:type_name -> changes.MappedItemDiff - 110, // 47: changes.StartChangeAnalysisRequest.blastRadiusConfigOverride:type_name -> config.BlastRadiusConfig - 111, // 48: changes.StartChangeAnalysisRequest.routineChangesConfigOverride:type_name -> config.RoutineChangesConfig - 112, // 49: changes.StartChangeAnalysisRequest.githubOrganisationProfileOverride:type_name -> config.GithubOrganisationProfile - 113, // 50: changes.ListHomeChangesRequest.pagination:type_name -> PaginationRequest - 54, // 51: changes.ListHomeChangesRequest.filters:type_name -> changes.ChangeFiltersRequest - 9, // 52: changes.ChangeFiltersRequest.risks:type_name -> changes.Risk.Severity - 6, // 53: changes.ChangeFiltersRequest.statuses:type_name -> changes.ChangeStatus - 114, // 54: changes.ChangeFiltersRequest.sortOrder:type_name -> SortOrder - 65, // 55: changes.ListHomeChangesResponse.changes:type_name -> changes.ChangeSummary - 115, // 56: changes.ListHomeChangesResponse.pagination:type_name -> PaginationResponse - 116, // 57: changes.ItemDiffSummary.item:type_name -> Reference - 3, // 58: changes.ItemDiffSummary.status:type_name -> changes.ItemDiffStatus - 117, // 59: changes.ItemDiffSummary.healthAfter:type_name -> Health - 116, // 60: changes.ItemDiff.item:type_name -> Reference - 3, // 61: changes.ItemDiff.status:type_name -> changes.ItemDiffStatus - 118, // 62: changes.ItemDiff.before:type_name -> Item - 118, // 63: changes.ItemDiff.after:type_name -> Item - 102, // 64: changes.EnrichedTags.tagValue:type_name -> changes.EnrichedTags.TagValueEntry - 62, // 65: changes.TagValue.userTagValue:type_name -> changes.UserTagValue - 63, // 66: changes.TagValue.autoTagValue:type_name -> changes.AutoTagValue - 5, // 67: changes.Label.type:type_name -> changes.LabelType - 6, // 68: changes.ChangeSummary.status:type_name -> changes.ChangeStatus - 106, // 69: changes.ChangeSummary.createdAt:type_name -> google.protobuf.Timestamp - 103, // 70: changes.ChangeSummary.tags:type_name -> changes.ChangeSummary.TagsEntry - 60, // 71: changes.ChangeSummary.enrichedTags:type_name -> changes.EnrichedTags - 64, // 72: changes.ChangeSummary.labels:type_name -> changes.Label - 69, // 73: changes.ChangeSummary.githubChangeInfo:type_name -> changes.GithubChangeInfo - 67, // 74: changes.Change.metadata:type_name -> changes.ChangeMetadata - 68, // 75: changes.Change.properties:type_name -> changes.ChangeProperties - 106, // 76: changes.ChangeMetadata.createdAt:type_name -> google.protobuf.Timestamp - 106, // 77: changes.ChangeMetadata.updatedAt:type_name -> google.protobuf.Timestamp - 6, // 78: changes.ChangeMetadata.status:type_name -> changes.ChangeStatus - 104, // 79: changes.ChangeMetadata.UnknownHealthChange:type_name -> changes.ChangeMetadata.HealthChange - 104, // 80: changes.ChangeMetadata.OkHealthChange:type_name -> changes.ChangeMetadata.HealthChange - 104, // 81: changes.ChangeMetadata.WarningHealthChange:type_name -> changes.ChangeMetadata.HealthChange - 104, // 82: changes.ChangeMetadata.ErrorHealthChange:type_name -> changes.ChangeMetadata.HealthChange - 104, // 83: 
changes.ChangeMetadata.PendingHealthChange:type_name -> changes.ChangeMetadata.HealthChange - 69, // 84: changes.ChangeMetadata.githubChangeInfo:type_name -> changes.GithubChangeInfo - 59, // 85: changes.ChangeProperties.plannedChanges:type_name -> changes.ItemDiff - 105, // 86: changes.ChangeProperties.tags:type_name -> changes.ChangeProperties.TagsEntry - 60, // 87: changes.ChangeProperties.enrichedTags:type_name -> changes.EnrichedTags - 64, // 88: changes.ChangeProperties.labels:type_name -> changes.Label - 66, // 89: changes.ListChangesResponse.changes:type_name -> changes.Change - 6, // 90: changes.ListChangesByStatusRequest.status:type_name -> changes.ChangeStatus - 66, // 91: changes.ListChangesByStatusResponse.changes:type_name -> changes.Change - 68, // 92: changes.CreateChangeRequest.properties:type_name -> changes.ChangeProperties - 66, // 93: changes.CreateChangeResponse.change:type_name -> changes.Change - 4, // 94: changes.GetChangeSummaryRequest.changeOutputFormat:type_name -> changes.ChangeOutputFormat - 9, // 95: changes.GetChangeSummaryRequest.riskSeverityFilter:type_name -> changes.Risk.Severity - 4, // 96: changes.GetChangeSignalsRequest.changeOutputFormat:type_name -> changes.ChangeOutputFormat - 66, // 97: changes.GetChangeResponse.change:type_name -> changes.Change - 99, // 98: changes.ChangeRiskMetadata.changeAnalysisStatus:type_name -> changes.ChangeAnalysisStatus - 98, // 99: changes.ChangeRiskMetadata.risks:type_name -> changes.Risk - 84, // 100: changes.GetChangeRisksResponse.changeRiskMetadata:type_name -> changes.ChangeRiskMetadata - 68, // 101: changes.UpdateChangeRequest.properties:type_name -> changes.ChangeProperties - 66, // 102: changes.UpdateChangeResponse.change:type_name -> changes.Change - 66, // 103: changes.ListChangesBySnapshotUUIDResponse.changes:type_name -> changes.Change - 7, // 104: changes.StartChangeResponse.state:type_name -> changes.StartChangeResponse.State - 8, // 105: changes.EndChangeResponse.state:type_name -> changes.EndChangeResponse.State - 9, // 106: changes.Risk.severity:type_name -> changes.Risk.Severity - 116, // 107: changes.Risk.relatedItems:type_name -> Reference - 10, // 108: changes.ChangeAnalysisStatus.status:type_name -> changes.ChangeAnalysisStatus.Status - 61, // 109: changes.EnrichedTags.TagValueEntry.value:type_name -> changes.TagValue - 70, // 110: changes.ChangesService.ListChanges:input_type -> changes.ListChangesRequest - 72, // 111: changes.ChangesService.ListChangesByStatus:input_type -> changes.ListChangesByStatusRequest - 74, // 112: changes.ChangesService.CreateChange:input_type -> changes.CreateChangeRequest - 76, // 113: changes.ChangesService.GetChange:input_type -> changes.GetChangeRequest - 77, // 114: changes.ChangesService.GetChangeByTicketLink:input_type -> changes.GetChangeByTicketLinkRequest - 78, // 115: changes.ChangesService.GetChangeSummary:input_type -> changes.GetChangeSummaryRequest - 31, // 116: changes.ChangesService.GetChangeTimelineV2:input_type -> changes.GetChangeTimelineV2Request - 83, // 117: changes.ChangesService.GetChangeRisks:input_type -> changes.GetChangeRisksRequest - 86, // 118: changes.ChangesService.UpdateChange:input_type -> changes.UpdateChangeRequest - 88, // 119: changes.ChangesService.DeleteChange:input_type -> changes.DeleteChangeRequest - 89, // 120: changes.ChangesService.ListChangesBySnapshotUUID:input_type -> changes.ListChangesBySnapshotUUIDRequest - 92, // 121: changes.ChangesService.RefreshState:input_type -> changes.RefreshStateRequest - 94, // 122: 
changes.ChangesService.StartChange:input_type -> changes.StartChangeRequest - 96, // 123: changes.ChangesService.EndChange:input_type -> changes.EndChangeRequest - 53, // 124: changes.ChangesService.ListHomeChanges:input_type -> changes.ListHomeChangesRequest - 51, // 125: changes.ChangesService.StartChangeAnalysis:input_type -> changes.StartChangeAnalysisRequest - 48, // 126: changes.ChangesService.ListChangingItemsSummary:input_type -> changes.ListChangingItemsSummaryRequest - 46, // 127: changes.ChangesService.GetDiff:input_type -> changes.GetDiffRequest - 56, // 128: changes.ChangesService.PopulateChangeFilters:input_type -> changes.PopulateChangeFiltersRequest - 100, // 129: changes.ChangesService.GenerateRiskFix:input_type -> changes.GenerateRiskFixRequest - 28, // 130: changes.ChangesService.GetHypothesesDetails:input_type -> changes.GetHypothesesDetailsRequest - 80, // 131: changes.ChangesService.GetChangeSignals:input_type -> changes.GetChangeSignalsRequest - 14, // 132: changes.LabelService.ListLabelRules:input_type -> changes.ListLabelRulesRequest - 16, // 133: changes.LabelService.CreateLabelRule:input_type -> changes.CreateLabelRuleRequest - 18, // 134: changes.LabelService.GetLabelRule:input_type -> changes.GetLabelRuleRequest - 20, // 135: changes.LabelService.UpdateLabelRule:input_type -> changes.UpdateLabelRuleRequest - 22, // 136: changes.LabelService.DeleteLabelRule:input_type -> changes.DeleteLabelRuleRequest - 24, // 137: changes.LabelService.TestLabelRule:input_type -> changes.TestLabelRuleRequest - 26, // 138: changes.LabelService.ReapplyLabelRuleInTimeRange:input_type -> changes.ReapplyLabelRuleInTimeRangeRequest - 71, // 139: changes.ChangesService.ListChanges:output_type -> changes.ListChangesResponse - 73, // 140: changes.ChangesService.ListChangesByStatus:output_type -> changes.ListChangesByStatusResponse - 75, // 141: changes.ChangesService.CreateChange:output_type -> changes.CreateChangeResponse - 82, // 142: changes.ChangesService.GetChange:output_type -> changes.GetChangeResponse - 82, // 143: changes.ChangesService.GetChangeByTicketLink:output_type -> changes.GetChangeResponse - 79, // 144: changes.ChangesService.GetChangeSummary:output_type -> changes.GetChangeSummaryResponse - 32, // 145: changes.ChangesService.GetChangeTimelineV2:output_type -> changes.GetChangeTimelineV2Response - 85, // 146: changes.ChangesService.GetChangeRisks:output_type -> changes.GetChangeRisksResponse - 87, // 147: changes.ChangesService.UpdateChange:output_type -> changes.UpdateChangeResponse - 91, // 148: changes.ChangesService.DeleteChange:output_type -> changes.DeleteChangeResponse - 90, // 149: changes.ChangesService.ListChangesBySnapshotUUID:output_type -> changes.ListChangesBySnapshotUUIDResponse - 93, // 150: changes.ChangesService.RefreshState:output_type -> changes.RefreshStateResponse - 95, // 151: changes.ChangesService.StartChange:output_type -> changes.StartChangeResponse - 97, // 152: changes.ChangesService.EndChange:output_type -> changes.EndChangeResponse - 55, // 153: changes.ChangesService.ListHomeChanges:output_type -> changes.ListHomeChangesResponse - 52, // 154: changes.ChangesService.StartChangeAnalysis:output_type -> changes.StartChangeAnalysisResponse - 49, // 155: changes.ChangesService.ListChangingItemsSummary:output_type -> changes.ListChangingItemsSummaryResponse - 47, // 156: changes.ChangesService.GetDiff:output_type -> changes.GetDiffResponse - 57, // 157: changes.ChangesService.PopulateChangeFilters:output_type -> 
changes.PopulateChangeFiltersResponse - 101, // 158: changes.ChangesService.GenerateRiskFix:output_type -> changes.GenerateRiskFixResponse - 29, // 159: changes.ChangesService.GetHypothesesDetails:output_type -> changes.GetHypothesesDetailsResponse - 81, // 160: changes.ChangesService.GetChangeSignals:output_type -> changes.GetChangeSignalsResponse - 15, // 161: changes.LabelService.ListLabelRules:output_type -> changes.ListLabelRulesResponse - 17, // 162: changes.LabelService.CreateLabelRule:output_type -> changes.CreateLabelRuleResponse - 19, // 163: changes.LabelService.GetLabelRule:output_type -> changes.GetLabelRuleResponse - 21, // 164: changes.LabelService.UpdateLabelRule:output_type -> changes.UpdateLabelRuleResponse - 23, // 165: changes.LabelService.DeleteLabelRule:output_type -> changes.DeleteLabelRuleResponse - 25, // 166: changes.LabelService.TestLabelRule:output_type -> changes.TestLabelRuleResponse - 27, // 167: changes.LabelService.ReapplyLabelRuleInTimeRange:output_type -> changes.ReapplyLabelRuleInTimeRangeResponse - 139, // [139:168] is the sub-list for method output_type - 110, // [110:139] is the sub-list for method input_type - 110, // [110:110] is the sub-list for extension type_name - 110, // [110:110] is the sub-list for extension extendee - 0, // [0:110] is the sub-list for field type_name + 51, // 30: changes.MappedItemsTimelineEntry.mappedItems:type_name -> changes.MappedItemDiff + 36, // 31: changes.MappedItemsTimelineEntry.items:type_name -> changes.MappedItemTimelineSummary + 42, // 32: changes.FormHypothesesTimelineEntry.hypotheses:type_name -> changes.HypothesisSummary + 42, // 33: changes.InvestigateHypothesesTimelineEntry.hypotheses:type_name -> changes.HypothesisSummary + 2, // 34: changes.HypothesisSummary.status:type_name -> changes.HypothesisStatus + 101, // 35: changes.CalculatedRisksTimelineEntry.risks:type_name -> changes.Risk + 65, // 36: changes.CalculatedLabelsTimelineEntry.labels:type_name -> changes.Label + 46, // 37: changes.ChangeValidationTimelineEntry.validationChecklist:type_name -> changes.ChangeValidationCategory + 60, // 38: changes.GetDiffResponse.expectedItems:type_name -> changes.ItemDiff + 60, // 39: changes.GetDiffResponse.unexpectedItems:type_name -> changes.ItemDiff + 110, // 40: changes.GetDiffResponse.edges:type_name -> Edge + 60, // 41: changes.GetDiffResponse.missingItems:type_name -> changes.ItemDiff + 59, // 42: changes.ListChangingItemsSummaryResponse.items:type_name -> changes.ItemDiffSummary + 60, // 43: changes.MappedItemDiff.item:type_name -> changes.ItemDiff + 111, // 44: changes.MappedItemDiff.mappingQuery:type_name -> Query + 112, // 45: changes.MappedItemDiff.mappingError:type_name -> QueryError + 1, // 46: changes.MappedItemDiff.mapping_status:type_name -> changes.MappedItemMappingStatus + 51, // 47: changes.StartChangeAnalysisRequest.changingItems:type_name -> changes.MappedItemDiff + 113, // 48: changes.StartChangeAnalysisRequest.blastRadiusConfigOverride:type_name -> config.BlastRadiusConfig + 114, // 49: changes.StartChangeAnalysisRequest.routineChangesConfigOverride:type_name -> config.RoutineChangesConfig + 115, // 50: changes.StartChangeAnalysisRequest.githubOrganisationProfileOverride:type_name -> config.GithubOrganisationProfile + 116, // 51: changes.ListHomeChangesRequest.pagination:type_name -> PaginationRequest + 55, // 52: changes.ListHomeChangesRequest.filters:type_name -> changes.ChangeFiltersRequest + 10, // 53: changes.ChangeFiltersRequest.risks:type_name -> changes.Risk.Severity + 7, // 54: 
changes.ChangeFiltersRequest.statuses:type_name -> changes.ChangeStatus + 117, // 55: changes.ChangeFiltersRequest.sortOrder:type_name -> SortOrder + 66, // 56: changes.ListHomeChangesResponse.changes:type_name -> changes.ChangeSummary + 118, // 57: changes.ListHomeChangesResponse.pagination:type_name -> PaginationResponse + 119, // 58: changes.ItemDiffSummary.item:type_name -> Reference + 4, // 59: changes.ItemDiffSummary.status:type_name -> changes.ItemDiffStatus + 120, // 60: changes.ItemDiffSummary.healthAfter:type_name -> Health + 119, // 61: changes.ItemDiff.item:type_name -> Reference + 4, // 62: changes.ItemDiff.status:type_name -> changes.ItemDiffStatus + 121, // 63: changes.ItemDiff.before:type_name -> Item + 121, // 64: changes.ItemDiff.after:type_name -> Item + 105, // 65: changes.EnrichedTags.tagValue:type_name -> changes.EnrichedTags.TagValueEntry + 63, // 66: changes.TagValue.userTagValue:type_name -> changes.UserTagValue + 64, // 67: changes.TagValue.autoTagValue:type_name -> changes.AutoTagValue + 6, // 68: changes.Label.type:type_name -> changes.LabelType + 7, // 69: changes.ChangeSummary.status:type_name -> changes.ChangeStatus + 109, // 70: changes.ChangeSummary.createdAt:type_name -> google.protobuf.Timestamp + 106, // 71: changes.ChangeSummary.tags:type_name -> changes.ChangeSummary.TagsEntry + 61, // 72: changes.ChangeSummary.enrichedTags:type_name -> changes.EnrichedTags + 65, // 73: changes.ChangeSummary.labels:type_name -> changes.Label + 70, // 74: changes.ChangeSummary.githubChangeInfo:type_name -> changes.GithubChangeInfo + 68, // 75: changes.Change.metadata:type_name -> changes.ChangeMetadata + 69, // 76: changes.Change.properties:type_name -> changes.ChangeProperties + 109, // 77: changes.ChangeMetadata.createdAt:type_name -> google.protobuf.Timestamp + 109, // 78: changes.ChangeMetadata.updatedAt:type_name -> google.protobuf.Timestamp + 7, // 79: changes.ChangeMetadata.status:type_name -> changes.ChangeStatus + 107, // 80: changes.ChangeMetadata.UnknownHealthChange:type_name -> changes.ChangeMetadata.HealthChange + 107, // 81: changes.ChangeMetadata.OkHealthChange:type_name -> changes.ChangeMetadata.HealthChange + 107, // 82: changes.ChangeMetadata.WarningHealthChange:type_name -> changes.ChangeMetadata.HealthChange + 107, // 83: changes.ChangeMetadata.ErrorHealthChange:type_name -> changes.ChangeMetadata.HealthChange + 107, // 84: changes.ChangeMetadata.PendingHealthChange:type_name -> changes.ChangeMetadata.HealthChange + 70, // 85: changes.ChangeMetadata.githubChangeInfo:type_name -> changes.GithubChangeInfo + 60, // 86: changes.ChangeProperties.plannedChanges:type_name -> changes.ItemDiff + 108, // 87: changes.ChangeProperties.tags:type_name -> changes.ChangeProperties.TagsEntry + 61, // 88: changes.ChangeProperties.enrichedTags:type_name -> changes.EnrichedTags + 65, // 89: changes.ChangeProperties.labels:type_name -> changes.Label + 67, // 90: changes.ListChangesResponse.changes:type_name -> changes.Change + 7, // 91: changes.ListChangesByStatusRequest.status:type_name -> changes.ChangeStatus + 67, // 92: changes.ListChangesByStatusResponse.changes:type_name -> changes.Change + 69, // 93: changes.CreateChangeRequest.properties:type_name -> changes.ChangeProperties + 67, // 94: changes.CreateChangeResponse.change:type_name -> changes.Change + 5, // 95: changes.GetChangeSummaryRequest.changeOutputFormat:type_name -> changes.ChangeOutputFormat + 10, // 96: changes.GetChangeSummaryRequest.riskSeverityFilter:type_name -> changes.Risk.Severity + 5, // 97: 
changes.GetChangeSignalsRequest.changeOutputFormat:type_name -> changes.ChangeOutputFormat + 67, // 98: changes.GetChangeResponse.change:type_name -> changes.Change + 102, // 99: changes.ChangeRiskMetadata.changeAnalysisStatus:type_name -> changes.ChangeAnalysisStatus + 101, // 100: changes.ChangeRiskMetadata.risks:type_name -> changes.Risk + 85, // 101: changes.GetChangeRisksResponse.changeRiskMetadata:type_name -> changes.ChangeRiskMetadata + 69, // 102: changes.UpdateChangeRequest.properties:type_name -> changes.ChangeProperties + 67, // 103: changes.UpdateChangeResponse.change:type_name -> changes.Change + 67, // 104: changes.ListChangesBySnapshotUUIDResponse.changes:type_name -> changes.Change + 8, // 105: changes.StartChangeResponse.state:type_name -> changes.StartChangeResponse.State + 9, // 106: changes.EndChangeResponse.state:type_name -> changes.EndChangeResponse.State + 10, // 107: changes.Risk.severity:type_name -> changes.Risk.Severity + 119, // 108: changes.Risk.relatedItems:type_name -> Reference + 11, // 109: changes.ChangeAnalysisStatus.status:type_name -> changes.ChangeAnalysisStatus.Status + 62, // 110: changes.EnrichedTags.TagValueEntry.value:type_name -> changes.TagValue + 71, // 111: changes.ChangesService.ListChanges:input_type -> changes.ListChangesRequest + 73, // 112: changes.ChangesService.ListChangesByStatus:input_type -> changes.ListChangesByStatusRequest + 75, // 113: changes.ChangesService.CreateChange:input_type -> changes.CreateChangeRequest + 77, // 114: changes.ChangesService.GetChange:input_type -> changes.GetChangeRequest + 78, // 115: changes.ChangesService.GetChangeByTicketLink:input_type -> changes.GetChangeByTicketLinkRequest + 79, // 116: changes.ChangesService.GetChangeSummary:input_type -> changes.GetChangeSummaryRequest + 32, // 117: changes.ChangesService.GetChangeTimelineV2:input_type -> changes.GetChangeTimelineV2Request + 84, // 118: changes.ChangesService.GetChangeRisks:input_type -> changes.GetChangeRisksRequest + 87, // 119: changes.ChangesService.UpdateChange:input_type -> changes.UpdateChangeRequest + 89, // 120: changes.ChangesService.DeleteChange:input_type -> changes.DeleteChangeRequest + 90, // 121: changes.ChangesService.ListChangesBySnapshotUUID:input_type -> changes.ListChangesBySnapshotUUIDRequest + 93, // 122: changes.ChangesService.RefreshState:input_type -> changes.RefreshStateRequest + 95, // 123: changes.ChangesService.StartChange:input_type -> changes.StartChangeRequest + 97, // 124: changes.ChangesService.EndChange:input_type -> changes.EndChangeRequest + 95, // 125: changes.ChangesService.StartChangeSimple:input_type -> changes.StartChangeRequest + 97, // 126: changes.ChangesService.EndChangeSimple:input_type -> changes.EndChangeRequest + 54, // 127: changes.ChangesService.ListHomeChanges:input_type -> changes.ListHomeChangesRequest + 52, // 128: changes.ChangesService.StartChangeAnalysis:input_type -> changes.StartChangeAnalysisRequest + 49, // 129: changes.ChangesService.ListChangingItemsSummary:input_type -> changes.ListChangingItemsSummaryRequest + 47, // 130: changes.ChangesService.GetDiff:input_type -> changes.GetDiffRequest + 57, // 131: changes.ChangesService.PopulateChangeFilters:input_type -> changes.PopulateChangeFiltersRequest + 103, // 132: changes.ChangesService.GenerateRiskFix:input_type -> changes.GenerateRiskFixRequest + 29, // 133: changes.ChangesService.GetHypothesesDetails:input_type -> changes.GetHypothesesDetailsRequest + 81, // 134: changes.ChangesService.GetChangeSignals:input_type -> 
changes.GetChangeSignalsRequest
+	15, // 135: changes.LabelService.ListLabelRules:input_type -> changes.ListLabelRulesRequest
+	17, // 136: changes.LabelService.CreateLabelRule:input_type -> changes.CreateLabelRuleRequest
+	19, // 137: changes.LabelService.GetLabelRule:input_type -> changes.GetLabelRuleRequest
+	21, // 138: changes.LabelService.UpdateLabelRule:input_type -> changes.UpdateLabelRuleRequest
+	23, // 139: changes.LabelService.DeleteLabelRule:input_type -> changes.DeleteLabelRuleRequest
+	25, // 140: changes.LabelService.TestLabelRule:input_type -> changes.TestLabelRuleRequest
+	27, // 141: changes.LabelService.ReapplyLabelRuleInTimeRange:input_type -> changes.ReapplyLabelRuleInTimeRangeRequest
+	72, // 142: changes.ChangesService.ListChanges:output_type -> changes.ListChangesResponse
+	74, // 143: changes.ChangesService.ListChangesByStatus:output_type -> changes.ListChangesByStatusResponse
+	76, // 144: changes.ChangesService.CreateChange:output_type -> changes.CreateChangeResponse
+	83, // 145: changes.ChangesService.GetChange:output_type -> changes.GetChangeResponse
+	83, // 146: changes.ChangesService.GetChangeByTicketLink:output_type -> changes.GetChangeResponse
+	80, // 147: changes.ChangesService.GetChangeSummary:output_type -> changes.GetChangeSummaryResponse
+	33, // 148: changes.ChangesService.GetChangeTimelineV2:output_type -> changes.GetChangeTimelineV2Response
+	86, // 149: changes.ChangesService.GetChangeRisks:output_type -> changes.GetChangeRisksResponse
+	88, // 150: changes.ChangesService.UpdateChange:output_type -> changes.UpdateChangeResponse
+	92, // 151: changes.ChangesService.DeleteChange:output_type -> changes.DeleteChangeResponse
+	91, // 152: changes.ChangesService.ListChangesBySnapshotUUID:output_type -> changes.ListChangesBySnapshotUUIDResponse
+	94, // 153: changes.ChangesService.RefreshState:output_type -> changes.RefreshStateResponse
+	96, // 154: changes.ChangesService.StartChange:output_type -> changes.StartChangeResponse
+	98, // 155: changes.ChangesService.EndChange:output_type -> changes.EndChangeResponse
+	99, // 156: changes.ChangesService.StartChangeSimple:output_type -> changes.StartChangeSimpleResponse
+	100, // 157: changes.ChangesService.EndChangeSimple:output_type -> changes.EndChangeSimpleResponse
+	56, // 158: changes.ChangesService.ListHomeChanges:output_type -> changes.ListHomeChangesResponse
+	53, // 159: changes.ChangesService.StartChangeAnalysis:output_type -> changes.StartChangeAnalysisResponse
+	50, // 160: changes.ChangesService.ListChangingItemsSummary:output_type -> changes.ListChangingItemsSummaryResponse
+	48, // 161: changes.ChangesService.GetDiff:output_type -> changes.GetDiffResponse
+	58, // 162: changes.ChangesService.PopulateChangeFilters:output_type -> changes.PopulateChangeFiltersResponse
+	104, // 163: changes.ChangesService.GenerateRiskFix:output_type -> changes.GenerateRiskFixResponse
+	30, // 164: changes.ChangesService.GetHypothesesDetails:output_type -> changes.GetHypothesesDetailsResponse
+	82, // 165: changes.ChangesService.GetChangeSignals:output_type -> changes.GetChangeSignalsResponse
+	16, // 166: changes.LabelService.ListLabelRules:output_type -> changes.ListLabelRulesResponse
+	18, // 167: changes.LabelService.CreateLabelRule:output_type -> changes.CreateLabelRuleResponse
+	20, // 168: changes.LabelService.GetLabelRule:output_type -> changes.GetLabelRuleResponse
+	22, // 169: changes.LabelService.UpdateLabelRule:output_type -> changes.UpdateLabelRuleResponse
+	24, // 170: changes.LabelService.DeleteLabelRule:output_type -> changes.DeleteLabelRuleResponse
+	26, // 171: changes.LabelService.TestLabelRule:output_type -> changes.TestLabelRuleResponse
+	28, // 172: changes.LabelService.ReapplyLabelRuleInTimeRange:output_type -> changes.ReapplyLabelRuleInTimeRangeResponse
+	142, // [142:173] is the sub-list for method output_type
+	111, // [111:142] is the sub-list for method input_type
+	111, // [111:111] is the sub-list for extension type_name
+	111, // [111:111] is the sub-list for extension extendee
+	0, // [0:111] is the sub-list for field type_name
 }
 
 func init() { file_changes_proto_init() }
@@ -7053,8 +7236,8 @@ func file_changes_proto_init() {
 		File: protoimpl.DescBuilder{
 			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
 			RawDescriptor: unsafe.Slice(unsafe.StringData(file_changes_proto_rawDesc), len(file_changes_proto_rawDesc)),
-			NumEnums:      11,
-			NumMessages:   95,
+			NumEnums:      12,
+			NumMessages:   97,
 			NumExtensions: 0,
 			NumServices:   2,
 		},
diff --git a/sdp-go/changes_test.go b/sdp-go/changes_test.go
index 2f4001d2..26300a24 100644
--- a/sdp-go/changes_test.go
+++ b/sdp-go/changes_test.go
@@ -109,7 +109,7 @@ func TestFindInProgressEntry(t *testing.T) {
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			name, status, err := TimelineFindInProgressEntry(tt.entries)
+			name, _, status, err := TimelineFindInProgressEntry(tt.entries)
 
 			if tt.expectError && err == nil {
 				t.Errorf("Expected an error, got nil")
@@ -130,6 +130,151 @@ func TestFindInProgressEntry(t *testing.T) {
 		})
 	}
 }
 
+func TestTimelineEntryContentDescription(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		entry    *ChangeTimelineEntryV2
+		expected string
+	}{
+		{
+			name: "mapped items",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_MappedItems{
+					MappedItems: &MappedItemsTimelineEntry{
+						MappedItems: []*MappedItemDiff{{}, {}, {}},
+					},
+				},
+			},
+			expected: "3 mapped items",
+		},
+		{
+			name: "calculated blast radius",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_CalculatedBlastRadius{
+					CalculatedBlastRadius: &CalculatedBlastRadiusTimelineEntry{
+						NumItems: 10,
+						NumEdges: 25,
+					},
+				},
+			},
+			expected: "10 items, 25 edges",
+		},
+		{
+			name: "calculated risks",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_CalculatedRisks{
+					CalculatedRisks: &CalculatedRisksTimelineEntry{
+						Risks: []*Risk{{}, {}},
+					},
+				},
+			},
+			expected: "2 risks",
+		},
+		{
+			name: "calculated labels",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_CalculatedLabels{
+					CalculatedLabels: &CalculatedLabelsTimelineEntry{
+						Labels: []*Label{{}, {}, {}, {}},
+					},
+				},
+			},
+			expected: "4 labels",
+		},
+		{
+			name: "change validation",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_ChangeValidation{
+					ChangeValidation: &ChangeValidationTimelineEntry{
+						ValidationChecklist: []*ChangeValidationCategory{{}},
+					},
+				},
+			},
+			expected: "1 validation categories",
+		},
+		{
+			name: "form hypotheses",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_FormHypotheses{
+					FormHypotheses: &FormHypothesesTimelineEntry{
+						NumHypotheses: 5,
+					},
+				},
+			},
+			expected: "5 hypotheses",
+		},
+		{
+			name: "investigate hypotheses",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_InvestigateHypotheses{
+					InvestigateHypotheses: &InvestigateHypothesesTimelineEntry{
+						NumProven:        2,
+						NumDisproven:     3,
+						NumInvestigating: 1,
+					},
+				},
+			},
+			expected: "2 proven, 3 disproven, 1 investigating",
+		},
+		{
+			name: "record observations",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_RecordObservations{
+					RecordObservations: &RecordObservationsTimelineEntry{
+						NumObservations: 42,
+					},
+				},
+			},
+			expected: "42 observations",
+		},
+		{
+			name: "error content",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_Error{
+					Error: "something went wrong",
+				},
+			},
+			expected: "something went wrong",
+		},
+		{
+			name: "status message",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_StatusMessage{
+					StatusMessage: "processing data",
+				},
+			},
+			expected: "processing data",
+		},
+		{
+			name: "empty content",
+			entry: &ChangeTimelineEntryV2{
+				Content: &ChangeTimelineEntryV2_Empty{
+					Empty: &EmptyContent{},
+				},
+			},
+			expected: "",
+		},
+		{
+			name: "nil content",
+			entry: &ChangeTimelineEntryV2{
+				Content: nil,
+			},
+			expected: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := TimelineEntryContentDescription(tt.entry)
+			if result != tt.expected {
+				t.Errorf("Expected %q, got %q", tt.expected, result)
+			}
+		})
+	}
+}
+
 func TestValidateRoutineChangesConfig(t *testing.T) {
 	t.Parallel()
diff --git a/sdp-go/sdpconnect/changes.connect.go b/sdp-go/sdpconnect/changes.connect.go
index f0baba9b..5b828fa8 100644
--- a/sdp-go/sdpconnect/changes.connect.go
+++ b/sdp-go/sdpconnect/changes.connect.go
@@ -77,6 +77,12 @@ const (
 	// ChangesServiceEndChangeProcedure is the fully-qualified name of the ChangesService's EndChange
 	// RPC.
 	ChangesServiceEndChangeProcedure = "/changes.ChangesService/EndChange"
+	// ChangesServiceStartChangeSimpleProcedure is the fully-qualified name of the ChangesService's
+	// StartChangeSimple RPC.
+	ChangesServiceStartChangeSimpleProcedure = "/changes.ChangesService/StartChangeSimple"
+	// ChangesServiceEndChangeSimpleProcedure is the fully-qualified name of the ChangesService's
+	// EndChangeSimple RPC.
+	ChangesServiceEndChangeSimpleProcedure = "/changes.ChangesService/EndChangeSimple"
 	// ChangesServiceListHomeChangesProcedure is the fully-qualified name of the ChangesService's
 	// ListHomeChanges RPC.
 	ChangesServiceListHomeChangesProcedure = "/changes.ChangesService/ListHomeChanges"
@@ -160,6 +166,12 @@ type ChangesServiceClient interface {
 	// the change diff and stores it as a list of DiffedItems and
 	// advances the change status to `STATUS_DONE`
 	EndChange(context.Context, *connect.Request[sdp_go.EndChangeRequest]) (*connect.ServerStreamForClient[sdp_go.EndChangeResponse], error)
+	// Simple version of StartChange that returns immediately after enqueuing the job.
+	// Use this instead of StartChange for non-streaming clients.
+	StartChangeSimple(context.Context, *connect.Request[sdp_go.StartChangeRequest]) (*connect.Response[sdp_go.StartChangeSimpleResponse], error)
+	// Simple version of EndChange that returns immediately after enqueuing the job.
+	// Use this instead of EndChange for non-streaming clients.
+	EndChangeSimple(context.Context, *connect.Request[sdp_go.EndChangeRequest]) (*connect.Response[sdp_go.EndChangeSimpleResponse], error)
 	// Lists all changes, designed for use in the changes home page
 	ListHomeChanges(context.Context, *connect.Request[sdp_go.ListHomeChangesRequest]) (*connect.Response[sdp_go.ListHomeChangesResponse], error)
 	// Start the change analysis process.
This will calculate various things @@ -284,6 +296,18 @@ func NewChangesServiceClient(httpClient connect.HTTPClient, baseURL string, opts connect.WithSchema(changesServiceMethods.ByName("EndChange")), connect.WithClientOptions(opts...), ), + startChangeSimple: connect.NewClient[sdp_go.StartChangeRequest, sdp_go.StartChangeSimpleResponse]( + httpClient, + baseURL+ChangesServiceStartChangeSimpleProcedure, + connect.WithSchema(changesServiceMethods.ByName("StartChangeSimple")), + connect.WithClientOptions(opts...), + ), + endChangeSimple: connect.NewClient[sdp_go.EndChangeRequest, sdp_go.EndChangeSimpleResponse]( + httpClient, + baseURL+ChangesServiceEndChangeSimpleProcedure, + connect.WithSchema(changesServiceMethods.ByName("EndChangeSimple")), + connect.WithClientOptions(opts...), + ), listHomeChanges: connect.NewClient[sdp_go.ListHomeChangesRequest, sdp_go.ListHomeChangesResponse]( httpClient, baseURL+ChangesServiceListHomeChangesProcedure, @@ -351,6 +375,8 @@ type changesServiceClient struct { refreshState *connect.Client[sdp_go.RefreshStateRequest, sdp_go.RefreshStateResponse] startChange *connect.Client[sdp_go.StartChangeRequest, sdp_go.StartChangeResponse] endChange *connect.Client[sdp_go.EndChangeRequest, sdp_go.EndChangeResponse] + startChangeSimple *connect.Client[sdp_go.StartChangeRequest, sdp_go.StartChangeSimpleResponse] + endChangeSimple *connect.Client[sdp_go.EndChangeRequest, sdp_go.EndChangeSimpleResponse] listHomeChanges *connect.Client[sdp_go.ListHomeChangesRequest, sdp_go.ListHomeChangesResponse] startChangeAnalysis *connect.Client[sdp_go.StartChangeAnalysisRequest, sdp_go.StartChangeAnalysisResponse] listChangingItemsSummary *connect.Client[sdp_go.ListChangingItemsSummaryRequest, sdp_go.ListChangingItemsSummaryResponse] @@ -431,6 +457,16 @@ func (c *changesServiceClient) EndChange(ctx context.Context, req *connect.Reque return c.endChange.CallServerStream(ctx, req) } +// StartChangeSimple calls changes.ChangesService.StartChangeSimple. +func (c *changesServiceClient) StartChangeSimple(ctx context.Context, req *connect.Request[sdp_go.StartChangeRequest]) (*connect.Response[sdp_go.StartChangeSimpleResponse], error) { + return c.startChangeSimple.CallUnary(ctx, req) +} + +// EndChangeSimple calls changes.ChangesService.EndChangeSimple. +func (c *changesServiceClient) EndChangeSimple(ctx context.Context, req *connect.Request[sdp_go.EndChangeRequest]) (*connect.Response[sdp_go.EndChangeSimpleResponse], error) { + return c.endChangeSimple.CallUnary(ctx, req) +} + // ListHomeChanges calls changes.ChangesService.ListHomeChanges. func (c *changesServiceClient) ListHomeChanges(ctx context.Context, req *connect.Request[sdp_go.ListHomeChangesRequest]) (*connect.Response[sdp_go.ListHomeChangesResponse], error) { return c.listHomeChanges.CallUnary(ctx, req) @@ -508,6 +544,12 @@ type ChangesServiceHandler interface { // the change diff and stores it as a list of DiffedItems and // advances the change status to `STATUS_DONE` EndChange(context.Context, *connect.Request[sdp_go.EndChangeRequest], *connect.ServerStream[sdp_go.EndChangeResponse]) error + // Simple version of StartChange that returns immediately after enqueuing the job. + // Use this instead of StartChange for non-streaming clients. + StartChangeSimple(context.Context, *connect.Request[sdp_go.StartChangeRequest]) (*connect.Response[sdp_go.StartChangeSimpleResponse], error) + // Simple version of EndChange that returns immediately after enqueuing the job. + // Use this instead of EndChange for non-streaming clients. 
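For non-streaming callers, the two new unary RPCs are invoked like any other Connect method. A minimal caller-side sketch, not part of the patch: request fields are elided, baseURL and ctx are assumed to exist, and the client is built with the NewChangesServiceClient constructor added in this file.

	// Illustrative only: enqueue the start/end jobs and return as soon as they are accepted.
	client := sdpconnect.NewChangesServiceClient(http.DefaultClient, baseURL)
	if _, err := client.StartChangeSimple(ctx, connect.NewRequest(&sdp_go.StartChangeRequest{})); err != nil {
		return err
	}
	// ... later, once the change has been applied ...
	if _, err := client.EndChangeSimple(ctx, connect.NewRequest(&sdp_go.EndChangeRequest{})); err != nil {
		return err
	}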
+ EndChangeSimple(context.Context, *connect.Request[sdp_go.EndChangeRequest]) (*connect.Response[sdp_go.EndChangeSimpleResponse], error) // Lists all changes, designed for use in the changes home page ListHomeChanges(context.Context, *connect.Request[sdp_go.ListHomeChangesRequest]) (*connect.Response[sdp_go.ListHomeChangesResponse], error) // Start the change analysis process. This will calculate various things @@ -628,6 +670,18 @@ func NewChangesServiceHandler(svc ChangesServiceHandler, opts ...connect.Handler connect.WithSchema(changesServiceMethods.ByName("EndChange")), connect.WithHandlerOptions(opts...), ) + changesServiceStartChangeSimpleHandler := connect.NewUnaryHandler( + ChangesServiceStartChangeSimpleProcedure, + svc.StartChangeSimple, + connect.WithSchema(changesServiceMethods.ByName("StartChangeSimple")), + connect.WithHandlerOptions(opts...), + ) + changesServiceEndChangeSimpleHandler := connect.NewUnaryHandler( + ChangesServiceEndChangeSimpleProcedure, + svc.EndChangeSimple, + connect.WithSchema(changesServiceMethods.ByName("EndChangeSimple")), + connect.WithHandlerOptions(opts...), + ) changesServiceListHomeChangesHandler := connect.NewUnaryHandler( ChangesServiceListHomeChangesProcedure, svc.ListHomeChanges, @@ -706,6 +760,10 @@ func NewChangesServiceHandler(svc ChangesServiceHandler, opts ...connect.Handler changesServiceStartChangeHandler.ServeHTTP(w, r) case ChangesServiceEndChangeProcedure: changesServiceEndChangeHandler.ServeHTTP(w, r) + case ChangesServiceStartChangeSimpleProcedure: + changesServiceStartChangeSimpleHandler.ServeHTTP(w, r) + case ChangesServiceEndChangeSimpleProcedure: + changesServiceEndChangeSimpleHandler.ServeHTTP(w, r) case ChangesServiceListHomeChangesProcedure: changesServiceListHomeChangesHandler.ServeHTTP(w, r) case ChangesServiceStartChangeAnalysisProcedure: @@ -787,6 +845,14 @@ func (UnimplementedChangesServiceHandler) EndChange(context.Context, *connect.Re return connect.NewError(connect.CodeUnimplemented, errors.New("changes.ChangesService.EndChange is not implemented")) } +func (UnimplementedChangesServiceHandler) StartChangeSimple(context.Context, *connect.Request[sdp_go.StartChangeRequest]) (*connect.Response[sdp_go.StartChangeSimpleResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("changes.ChangesService.StartChangeSimple is not implemented")) +} + +func (UnimplementedChangesServiceHandler) EndChangeSimple(context.Context, *connect.Request[sdp_go.EndChangeRequest]) (*connect.Response[sdp_go.EndChangeSimpleResponse], error) { + return nil, connect.NewError(connect.CodeUnimplemented, errors.New("changes.ChangesService.EndChangeSimple is not implemented")) +} + func (UnimplementedChangesServiceHandler) ListHomeChanges(context.Context, *connect.Request[sdp_go.ListHomeChangesRequest]) (*connect.Response[sdp_go.ListHomeChangesResponse], error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("changes.ChangesService.ListHomeChanges is not implemented")) } diff --git a/sources/azure/clients/storage-accounts-client.go b/sources/azure/clients/storage-accounts-client.go index 075bd1b0..ea7575ab 100644 --- a/sources/azure/clients/storage-accounts-client.go +++ b/sources/azure/clients/storage-accounts-client.go @@ -15,7 +15,7 @@ type StorageAccountsPager = Pager[armstorage.AccountsClientListByResourceGroupRe // StorageAccountsClient is an interface for interacting with Azure storage accounts type StorageAccountsClient interface { Get(ctx context.Context, resourceGroupName string, 
accountName string) (armstorage.AccountsClientGetPropertiesResponse, error)
-	List(resourceGroupName string) StorageAccountsPager
+	NewListByResourceGroupPager(resourceGroupName string, options *armstorage.AccountsClientListByResourceGroupOptions) StorageAccountsPager
 }
 
 type storageAccountsClient struct {
@@ -26,8 +26,8 @@ func (a *storageAccountsClient) Get(ctx context.Context, resourceGroupName strin
 	return a.client.GetProperties(ctx, resourceGroupName, accountName, nil)
 }
 
-func (a *storageAccountsClient) List(resourceGroupName string) StorageAccountsPager {
-	return a.client.NewListByResourceGroupPager(resourceGroupName, nil)
+func (a *storageAccountsClient) NewListByResourceGroupPager(resourceGroupName string, options *armstorage.AccountsClientListByResourceGroupOptions) StorageAccountsPager {
+	return a.client.NewListByResourceGroupPager(resourceGroupName, options)
 }
 
 // NewStorageAccountsClient creates a new StorageAccountsClient from the Azure SDK client
diff --git a/sources/azure/cmd/root.go b/sources/azure/cmd/root.go
index 78cde3ec..2800afe0 100644
--- a/sources/azure/cmd/root.go
+++ b/sources/azure/cmd/root.go
@@ -24,64 +24,67 @@ var cfgFile string
 
 // rootCmd represents the base command when called without any subcommands
 var rootCmd = &cobra.Command{
-	Use:   "azure-source",
-	Short: "Remote primary source for Azure",
+	Use:          "azure-source",
+	Short:        "Remote primary source for Azure",
+	SilenceUsage: true,
 	Long: `This sources looks for Azure resources in your account.
 `,
-	Run: func(cmd *cobra.Command, args []string) {
-		ctx := context.Background()
+	RunE: func(cmd *cobra.Command, args []string) error {
+		ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+		defer stop()
 
 		defer tracing.LogRecoverToReturn(ctx, "azure-source.root")
 		healthCheckPort := viper.GetInt("health-check-port")
 		engineConfig, err := discovery.EngineConfigFromViper("azure", tracing.Version())
 		if err != nil {
-			log.WithError(err).Fatal("Could not create engine config")
+			log.WithError(err).Error("Could not create engine config")
+			return fmt.Errorf("could not create engine config: %w", err)
 		}
 
-		err = engineConfig.CreateClients()
+		// Create a basic engine first so we can serve health probes and heartbeats even if init fails
+		e, err := discovery.NewEngine(engineConfig)
 		if err != nil {
 			sentry.CaptureException(err)
-			log.WithError(err).Fatal("could not auth create clients")
+			log.WithError(err).Error("Could not create engine")
+			return fmt.Errorf("could not create engine: %w", err)
 		}
 
-		e, err := proc.Initialize(ctx, engineConfig, nil)
-		if err != nil {
-			log.WithError(err).Fatal("Could not initialize Azure source")
-		}
-
-		e.StartSendingHeartbeats(ctx)
-
+		// Serve health probes before initialization so they're available even on failure
 		e.ServeHealthProbes(healthCheckPort)
 
+		// Start the engine (NATS connection) before adapter init so heartbeats work
 		err = e.Start(ctx)
 		if err != nil {
-			log.WithFields(log.Fields{
-				"ovm.source.type":  "azure",
-				"ovm.source.error": err,
-			}).Fatal("Could not start engine")
+			sentry.CaptureException(err)
+			log.WithError(err).Error("Could not start engine")
+			return fmt.Errorf("could not start engine: %w", err)
 		}
 
-		sigs := make(chan os.Signal, 1)
-
-		signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
+		// Config validation (permanent errors — no retry, just idle with error)
+		azureCfg, cfgErr := proc.ConfigFromViper()
+		if cfgErr != nil {
+			log.WithError(cfgErr).Error("Azure source config error - pod will stay running with error status")
+			e.SetInitError(cfgErr)
+			sentry.CaptureException(cfgErr)
+		} else {
+			// Adapter init (retryable errors — backoff capped at 5 min)
+			e.InitialiseAdapters(ctx, func(ctx context.Context) error {
+				return proc.InitializeAdapters(ctx, e, azureCfg)
+			})
+		}
 
-		<-sigs
+		<-ctx.Done()
 
 		log.Info("Stopping engine")
 
 		err = e.Stop()
-
 		if err != nil {
-			log.WithFields(log.Fields{
-				"ovm.source.type":  "azure",
-				"ovm.source.error": err,
-			}).Error("Could not stop engine")
-
-			os.Exit(1)
+			log.WithError(err).Error("Could not stop engine")
+			return fmt.Errorf("could not stop engine: %w", err)
 		}
 
 		log.Info("Stopped")
 
-		os.Exit(0)
+		return nil
 	},
 }
@@ -127,7 +130,7 @@ func init() {
 	cobra.CheckErr(viper.BindPFlags(rootCmd.PersistentFlags()))
 
 	// Run this before we do anything to set up the loglevel
-	rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) {
+	rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
 		if lvl, err := log.ParseLevel(logLevel); err == nil {
 			log.SetLevel(lvl)
 		} else {
@@ -140,23 +143,29 @@ func init() {
 		log.AddHook(TerminationLogHook{})
 
 		// Bind flags that haven't been set to the values from viper of we have them
+		var bindErr error
 		cmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
 			// Bind the flag to viper only if it has a non-empty default
 			if f.DefValue != "" || f.Changed {
-				err := viper.BindPFlag(f.Name, f)
-				if err != nil {
-					log.WithError(err).Fatal("could not bind flag to viper")
+				if err := viper.BindPFlag(f.Name, f); err != nil {
+					bindErr = err
 				}
 			}
 		})
+		if bindErr != nil {
+			log.WithError(bindErr).Error("could not bind flag to viper")
+			return fmt.Errorf("could not bind flag to viper: %w", bindErr)
+		}
 
 		if viper.GetBool("json-log") {
 			logging.ConfigureLogrusJSON(log.StandardLogger())
 		}
 
 		if err := tracing.InitTracerWithUpstreams("azure-source", viper.GetString("honeycomb-api-key"), viper.GetString("sentry-dsn")); err != nil {
-			log.Fatal(err)
+			log.WithError(err).Error("could not init tracer")
+			return fmt.Errorf("could not init tracer: %w", err)
 		}
+		return nil
 	}
 
 	// shut down tracing at the end of the process
 	rootCmd.PersistentPostRun = func(cmd *cobra.Command, args []string) {
@@ -189,8 +198,7 @@ func (t TerminationLogHook) Levels() []log.Level {
 func (t TerminationLogHook) Fire(e *log.Entry) error {
 	// shutdown tracing first to ensure all spans are flushed
 	tracing.ShutdownTracer(context.Background())
-	tLog, err := os.OpenFile("/dev/termination-log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
-
+	tLog, err := os.OpenFile("/dev/termination-log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
 	if err != nil {
 		return err
 	}
diff --git a/sources/azure/integration-tests/authorization-role-assignment_test.go b/sources/azure/integration-tests/authorization-role-assignment_test.go
index 180aa095..8a15c3e8 100644
--- a/sources/azure/integration-tests/authorization-role-assignment_test.go
+++ b/sources/azure/integration-tests/authorization-role-assignment_test.go
@@ -98,8 +98,7 @@ func TestAuthorizationRoleAssignmentIntegration(t *testing.T) {
 
 		roleAssignmentWrapper := manual.NewAuthorizationRoleAssignment(
 			clients.NewRoleAssignmentsClient(roleAssignmentsClient),
-			subscriptionID,
-			integrationTestResourceGroup,
+			[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)},
 		)
 
 		scope := roleAssignmentWrapper.Scopes()[0]
@@ -150,8 +149,7 @@ func TestAuthorizationRoleAssignmentIntegration(t *testing.T) {
 
 		roleAssignmentWrapper := manual.NewAuthorizationRoleAssignment(
clients.NewRoleAssignmentsClient(roleAssignmentsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := roleAssignmentWrapper.Scopes()[0] @@ -198,8 +196,7 @@ func TestAuthorizationRoleAssignmentIntegration(t *testing.T) { roleAssignmentWrapper := manual.NewAuthorizationRoleAssignment( clients.NewRoleAssignmentsClient(roleAssignmentsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := roleAssignmentWrapper.Scopes()[0] @@ -248,8 +245,7 @@ func TestAuthorizationRoleAssignmentIntegration(t *testing.T) { roleAssignmentWrapper := manual.NewAuthorizationRoleAssignment( clients.NewRoleAssignmentsClient(roleAssignmentsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := roleAssignmentWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/batch-batch-accounts_test.go b/sources/azure/integration-tests/batch-batch-accounts_test.go index 0610403f..6ef8d0aa 100644 --- a/sources/azure/integration-tests/batch-batch-accounts_test.go +++ b/sources/azure/integration-tests/batch-batch-accounts_test.go @@ -116,8 +116,7 @@ func TestBatchAccountIntegration(t *testing.T) { batchWrapper := manual.NewBatchAccount( clients.NewBatchAccountsClient(batchClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := batchWrapper.Scopes()[0] @@ -155,8 +154,7 @@ func TestBatchAccountIntegration(t *testing.T) { batchWrapper := manual.NewBatchAccount( clients.NewBatchAccountsClient(batchClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := batchWrapper.Scopes()[0] @@ -203,8 +201,7 @@ func TestBatchAccountIntegration(t *testing.T) { batchWrapper := manual.NewBatchAccount( clients.NewBatchAccountsClient(batchClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := batchWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-availability-set_test.go b/sources/azure/integration-tests/compute-availability-set_test.go index e3054038..874d0f6f 100644 --- a/sources/azure/integration-tests/compute-availability-set_test.go +++ b/sources/azure/integration-tests/compute-availability-set_test.go @@ -161,8 +161,7 @@ func TestComputeAvailabilitySetIntegration(t *testing.T) { avSetWrapper := manual.NewComputeAvailabilitySet( clients.NewAvailabilitySetsClient(avSetClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := avSetWrapper.Scopes()[0] @@ -201,8 +200,7 @@ func TestComputeAvailabilitySetIntegration(t *testing.T) { avSetWrapper := manual.NewComputeAvailabilitySet( clients.NewAvailabilitySetsClient(avSetClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := 
avSetWrapper.Scopes()[0] @@ -249,8 +247,7 @@ func TestComputeAvailabilitySetIntegration(t *testing.T) { avSetWrapper := manual.NewComputeAvailabilitySet( clients.NewAvailabilitySetsClient(avSetClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := avSetWrapper.Scopes()[0] @@ -313,8 +310,7 @@ func TestComputeAvailabilitySetIntegration(t *testing.T) { avSetWrapper := manual.NewComputeAvailabilitySet( clients.NewAvailabilitySetsClient(avSetClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := avSetWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-disk-encryption-set_test.go b/sources/azure/integration-tests/compute-disk-encryption-set_test.go index 5e41b600..a800dea6 100644 --- a/sources/azure/integration-tests/compute-disk-encryption-set_test.go +++ b/sources/azure/integration-tests/compute-disk-encryption-set_test.go @@ -143,8 +143,7 @@ func TestComputeDiskEncryptionSetIntegration(t *testing.T) { desWrapper := manual.NewComputeDiskEncryptionSet( clients.NewDiskEncryptionSetsClient(desClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := desWrapper.Scopes()[0] @@ -176,8 +175,7 @@ func TestComputeDiskEncryptionSetIntegration(t *testing.T) { desWrapper := manual.NewComputeDiskEncryptionSet( clients.NewDiskEncryptionSetsClient(desClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := desWrapper.Scopes()[0] @@ -214,8 +212,7 @@ func TestComputeDiskEncryptionSetIntegration(t *testing.T) { desWrapper := manual.NewComputeDiskEncryptionSet( clients.NewDiskEncryptionSetsClient(desClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := desWrapper.Scopes()[0] @@ -245,8 +242,7 @@ func TestComputeDiskEncryptionSetIntegration(t *testing.T) { desWrapper := manual.NewComputeDiskEncryptionSet( clients.NewDiskEncryptionSetsClient(desClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := desWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-disk_test.go b/sources/azure/integration-tests/compute-disk_test.go index 8eb640ba..2932bbce 100644 --- a/sources/azure/integration-tests/compute-disk_test.go +++ b/sources/azure/integration-tests/compute-disk_test.go @@ -82,8 +82,7 @@ func TestComputeDiskIntegration(t *testing.T) { diskWrapper := manual.NewComputeDisk( clients.NewDisksClient(diskClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := diskWrapper.Scopes()[0] @@ -118,8 +117,7 @@ func TestComputeDiskIntegration(t *testing.T) { diskWrapper := manual.NewComputeDisk( clients.NewDisksClient(diskClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, 
integrationTestResourceGroup)}, ) scope := diskWrapper.Scopes()[0] @@ -163,8 +161,7 @@ func TestComputeDiskIntegration(t *testing.T) { diskWrapper := manual.NewComputeDisk( clients.NewDisksClient(diskClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := diskWrapper.Scopes()[0] @@ -205,8 +202,7 @@ func TestComputeDiskIntegration(t *testing.T) { diskWrapper := manual.NewComputeDisk( clients.NewDisksClient(diskClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := diskWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-image_test.go b/sources/azure/integration-tests/compute-image_test.go index 1e263013..a9c57f99 100644 --- a/sources/azure/integration-tests/compute-image_test.go +++ b/sources/azure/integration-tests/compute-image_test.go @@ -109,8 +109,7 @@ func TestComputeImageIntegration(t *testing.T) { imageWrapper := manual.NewComputeImage( clients.NewImagesClient(imageClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := imageWrapper.Scopes()[0] @@ -145,8 +144,7 @@ func TestComputeImageIntegration(t *testing.T) { imageWrapper := manual.NewComputeImage( clients.NewImagesClient(imageClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := imageWrapper.Scopes()[0] @@ -190,8 +188,7 @@ func TestComputeImageIntegration(t *testing.T) { imageWrapper := manual.NewComputeImage( clients.NewImagesClient(imageClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := imageWrapper.Scopes()[0] @@ -232,8 +229,7 @@ func TestComputeImageIntegration(t *testing.T) { imageWrapper := manual.NewComputeImage( clients.NewImagesClient(imageClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := imageWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-proximity-placement-group_test.go b/sources/azure/integration-tests/compute-proximity-placement-group_test.go index 66ef4913..c2aae14a 100644 --- a/sources/azure/integration-tests/compute-proximity-placement-group_test.go +++ b/sources/azure/integration-tests/compute-proximity-placement-group_test.go @@ -86,8 +86,7 @@ func TestComputeProximityPlacementGroupIntegration(t *testing.T) { ppgWrapper := manual.NewComputeProximityPlacementGroup( clients.NewProximityPlacementGroupsClient(ppgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := ppgWrapper.Scopes()[0] @@ -126,8 +125,7 @@ func TestComputeProximityPlacementGroupIntegration(t *testing.T) { ppgWrapper := manual.NewComputeProximityPlacementGroup( clients.NewProximityPlacementGroupsClient(ppgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope 
:= ppgWrapper.Scopes()[0] @@ -173,8 +171,7 @@ func TestComputeProximityPlacementGroupIntegration(t *testing.T) { ppgWrapper := manual.NewComputeProximityPlacementGroup( clients.NewProximityPlacementGroupsClient(ppgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := ppgWrapper.Scopes()[0] @@ -224,8 +221,7 @@ func TestComputeProximityPlacementGroupIntegration(t *testing.T) { ppgWrapper := manual.NewComputeProximityPlacementGroup( clients.NewProximityPlacementGroupsClient(ppgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := ppgWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-virtual-machine-extension_test.go b/sources/azure/integration-tests/compute-virtual-machine-extension_test.go index 731ce0f5..8da2091a 100644 --- a/sources/azure/integration-tests/compute-virtual-machine-extension_test.go +++ b/sources/azure/integration-tests/compute-virtual-machine-extension_test.go @@ -145,8 +145,7 @@ func TestComputeVirtualMachineExtensionIntegration(t *testing.T) { extensionWrapper := manual.NewComputeVirtualMachineExtension( clients.NewVirtualMachineExtensionsClient(extensionClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := extensionWrapper.Scopes()[0] @@ -192,8 +191,7 @@ func TestComputeVirtualMachineExtensionIntegration(t *testing.T) { extensionWrapper := manual.NewComputeVirtualMachineExtension( clients.NewVirtualMachineExtensionsClient(extensionClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := extensionWrapper.Scopes()[0] @@ -242,8 +240,7 @@ func TestComputeVirtualMachineExtensionIntegration(t *testing.T) { extensionWrapper := manual.NewComputeVirtualMachineExtension( clients.NewVirtualMachineExtensionsClient(extensionClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := extensionWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-virtual-machine-run-command_test.go b/sources/azure/integration-tests/compute-virtual-machine-run-command_test.go index 942e0a88..67482ac1 100644 --- a/sources/azure/integration-tests/compute-virtual-machine-run-command_test.go +++ b/sources/azure/integration-tests/compute-virtual-machine-run-command_test.go @@ -145,8 +145,7 @@ func TestComputeVirtualMachineRunCommandIntegration(t *testing.T) { runCommandWrapper := manual.NewComputeVirtualMachineRunCommand( clients.NewVirtualMachineRunCommandsClient(runCommandClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := runCommandWrapper.Scopes()[0] @@ -192,8 +191,7 @@ func TestComputeVirtualMachineRunCommandIntegration(t *testing.T) { runCommandWrapper := manual.NewComputeVirtualMachineRunCommand( clients.NewVirtualMachineRunCommandsClient(runCommandClient), - subscriptionID, - integrationTestResourceGroup, + 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := runCommandWrapper.Scopes()[0] @@ -242,8 +240,7 @@ func TestComputeVirtualMachineRunCommandIntegration(t *testing.T) { runCommandWrapper := manual.NewComputeVirtualMachineRunCommand( clients.NewVirtualMachineRunCommandsClient(runCommandClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := runCommandWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-virtual-machine-scale-set_test.go b/sources/azure/integration-tests/compute-virtual-machine-scale-set_test.go index 2f2d5f4e..0abf3148 100644 --- a/sources/azure/integration-tests/compute-virtual-machine-scale-set_test.go +++ b/sources/azure/integration-tests/compute-virtual-machine-scale-set_test.go @@ -118,8 +118,7 @@ func TestComputeVirtualMachineScaleSetIntegration(t *testing.T) { vmssWrapper := manual.NewComputeVirtualMachineScaleSet( clients.NewVirtualMachineScaleSetsClient(vmssClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vmssWrapper.Scopes()[0] @@ -158,8 +157,7 @@ func TestComputeVirtualMachineScaleSetIntegration(t *testing.T) { vmssWrapper := manual.NewComputeVirtualMachineScaleSet( clients.NewVirtualMachineScaleSetsClient(vmssClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vmssWrapper.Scopes()[0] @@ -206,8 +204,7 @@ func TestComputeVirtualMachineScaleSetIntegration(t *testing.T) { vmssWrapper := manual.NewComputeVirtualMachineScaleSet( clients.NewVirtualMachineScaleSetsClient(vmssClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vmssWrapper.Scopes()[0] @@ -290,8 +287,7 @@ func TestComputeVirtualMachineScaleSetIntegration(t *testing.T) { vmssWrapper := manual.NewComputeVirtualMachineScaleSet( clients.NewVirtualMachineScaleSetsClient(vmssClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vmssWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/compute-virtual-machine_test.go b/sources/azure/integration-tests/compute-virtual-machine_test.go index 1c695570..5db54d03 100644 --- a/sources/azure/integration-tests/compute-virtual-machine_test.go +++ b/sources/azure/integration-tests/compute-virtual-machine_test.go @@ -128,8 +128,7 @@ func TestComputeVirtualMachineIntegration(t *testing.T) { vmWrapper := manual.NewComputeVirtualMachine( clients.NewVirtualMachinesClient(vmClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vmWrapper.Scopes()[0] @@ -164,8 +163,7 @@ func TestComputeVirtualMachineIntegration(t *testing.T) { vmWrapper := manual.NewComputeVirtualMachine( clients.NewVirtualMachinesClient(vmClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := 
vmWrapper.Scopes()[0] @@ -209,8 +207,7 @@ func TestComputeVirtualMachineIntegration(t *testing.T) { vmWrapper := manual.NewComputeVirtualMachine( clients.NewVirtualMachinesClient(vmClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vmWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/dbforpostgresql-database_test.go b/sources/azure/integration-tests/dbforpostgresql-database_test.go index 7915f01b..a5704f46 100644 --- a/sources/azure/integration-tests/dbforpostgresql-database_test.go +++ b/sources/azure/integration-tests/dbforpostgresql-database_test.go @@ -106,8 +106,7 @@ func TestDBforPostgreSQLDatabaseIntegration(t *testing.T) { pgDbWrapper := manual.NewDBforPostgreSQLDatabase( clients.NewPostgreSQLDatabasesClient(postgreSQLDatabaseClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := pgDbWrapper.Scopes()[0] @@ -160,8 +159,7 @@ func TestDBforPostgreSQLDatabaseIntegration(t *testing.T) { pgDbWrapper := manual.NewDBforPostgreSQLDatabase( clients.NewPostgreSQLDatabasesClient(postgreSQLDatabaseClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := pgDbWrapper.Scopes()[0] @@ -208,8 +206,7 @@ func TestDBforPostgreSQLDatabaseIntegration(t *testing.T) { pgDbWrapper := manual.NewDBforPostgreSQLDatabase( clients.NewPostgreSQLDatabasesClient(postgreSQLDatabaseClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := pgDbWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/dbforpostgresql-flexible-server_test.go b/sources/azure/integration-tests/dbforpostgresql-flexible-server_test.go index 9f09caf8..b42ef5e5 100644 --- a/sources/azure/integration-tests/dbforpostgresql-flexible-server_test.go +++ b/sources/azure/integration-tests/dbforpostgresql-flexible-server_test.go @@ -80,8 +80,7 @@ func TestDBforPostgreSQLFlexibleServerIntegration(t *testing.T) { pgServerWrapper := manual.NewDBforPostgreSQLFlexibleServer( clients.NewPostgreSQLFlexibleServersClient(postgreSQLServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := pgServerWrapper.Scopes()[0] @@ -131,8 +130,7 @@ func TestDBforPostgreSQLFlexibleServerIntegration(t *testing.T) { pgServerWrapper := manual.NewDBforPostgreSQLFlexibleServer( clients.NewPostgreSQLFlexibleServersClient(postgreSQLServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := pgServerWrapper.Scopes()[0] @@ -178,8 +176,7 @@ func TestDBforPostgreSQLFlexibleServerIntegration(t *testing.T) { pgServerWrapper := manual.NewDBforPostgreSQLFlexibleServer( clients.NewPostgreSQLFlexibleServersClient(postgreSQLServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := pgServerWrapper.Scopes()[0] @@ -267,8 +264,7 @@ func 
TestDBforPostgreSQLFlexibleServerIntegration(t *testing.T) { pgServerWrapper := manual.NewDBforPostgreSQLFlexibleServer( clients.NewPostgreSQLFlexibleServersClient(postgreSQLServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := pgServerWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/documentdb-database-accounts_test.go b/sources/azure/integration-tests/documentdb-database-accounts_test.go index 86f16b96..88d581de 100644 --- a/sources/azure/integration-tests/documentdb-database-accounts_test.go +++ b/sources/azure/integration-tests/documentdb-database-accounts_test.go @@ -87,8 +87,7 @@ func TestDocumentDBDatabaseAccountsIntegration(t *testing.T) { cosmosWrapper := manual.NewDocumentDBDatabaseAccounts( clients.NewDocumentDBDatabaseAccountsClient(cosmosClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := cosmosWrapper.Scopes()[0] @@ -128,8 +127,7 @@ func TestDocumentDBDatabaseAccountsIntegration(t *testing.T) { cosmosWrapper := manual.NewDocumentDBDatabaseAccounts( clients.NewDocumentDBDatabaseAccountsClient(cosmosClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := cosmosWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/keyvault-managed-hsm_test.go b/sources/azure/integration-tests/keyvault-managed-hsm_test.go index 3a3d8281..200e4b5d 100644 --- a/sources/azure/integration-tests/keyvault-managed-hsm_test.go +++ b/sources/azure/integration-tests/keyvault-managed-hsm_test.go @@ -110,8 +110,7 @@ func TestKeyVaultManagedHSMIntegration(t *testing.T) { hsmWrapper := manual.NewKeyVaultManagedHSM( clients.NewManagedHSMsClient(managedHSMClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := hsmWrapper.Scopes()[0] @@ -151,8 +150,7 @@ func TestKeyVaultManagedHSMIntegration(t *testing.T) { hsmWrapper := manual.NewKeyVaultManagedHSM( clients.NewManagedHSMsClient(managedHSMClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := hsmWrapper.Scopes()[0] @@ -195,8 +193,7 @@ func TestKeyVaultManagedHSMIntegration(t *testing.T) { hsmWrapper := manual.NewKeyVaultManagedHSM( clients.NewManagedHSMsClient(managedHSMClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := hsmWrapper.Scopes()[0] @@ -241,8 +238,7 @@ func TestKeyVaultManagedHSMIntegration(t *testing.T) { hsmWrapper := manual.NewKeyVaultManagedHSM( clients.NewManagedHSMsClient(managedHSMClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := hsmWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/keyvault-secret_test.go b/sources/azure/integration-tests/keyvault-secret_test.go index ef29775b..4f4939e9 100644 --- a/sources/azure/integration-tests/keyvault-secret_test.go +++ 
b/sources/azure/integration-tests/keyvault-secret_test.go @@ -119,8 +119,7 @@ func TestKeyVaultSecretIntegration(t *testing.T) { secretWrapper := manual.NewKeyVaultSecret( clients.NewSecretsClient(secretsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := secretWrapper.Scopes()[0] @@ -165,8 +164,7 @@ func TestKeyVaultSecretIntegration(t *testing.T) { secretWrapper := manual.NewKeyVaultSecret( clients.NewSecretsClient(secretsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := secretWrapper.Scopes()[0] @@ -211,8 +209,7 @@ func TestKeyVaultSecretIntegration(t *testing.T) { secretWrapper := manual.NewKeyVaultSecret( clients.NewSecretsClient(secretsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := secretWrapper.Scopes()[0] @@ -254,8 +251,7 @@ func TestKeyVaultSecretIntegration(t *testing.T) { secretWrapper := manual.NewKeyVaultSecret( clients.NewSecretsClient(secretsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := secretWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/keyvault-vault_test.go b/sources/azure/integration-tests/keyvault-vault_test.go index 529e700b..29d9d390 100644 --- a/sources/azure/integration-tests/keyvault-vault_test.go +++ b/sources/azure/integration-tests/keyvault-vault_test.go @@ -87,8 +87,7 @@ func TestKeyVaultVaultIntegration(t *testing.T) { kvWrapper := manual.NewKeyVaultVault( clients.NewVaultsClient(keyVaultClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := kvWrapper.Scopes()[0] @@ -128,8 +127,7 @@ func TestKeyVaultVaultIntegration(t *testing.T) { kvWrapper := manual.NewKeyVaultVault( clients.NewVaultsClient(keyVaultClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := kvWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/managedidentity-user-assigned-identity_test.go b/sources/azure/integration-tests/managedidentity-user-assigned-identity_test.go index 2b4ef595..3d1e32e2 100644 --- a/sources/azure/integration-tests/managedidentity-user-assigned-identity_test.go +++ b/sources/azure/integration-tests/managedidentity-user-assigned-identity_test.go @@ -88,8 +88,7 @@ func TestManagedIdentityUserAssignedIdentityIntegration(t *testing.T) { identityWrapper := manual.NewManagedIdentityUserAssignedIdentity( clients.NewUserAssignedIdentitiesClient(identityClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := identityWrapper.Scopes()[0] @@ -129,8 +128,7 @@ func TestManagedIdentityUserAssignedIdentityIntegration(t *testing.T) { identityWrapper := manual.NewManagedIdentityUserAssignedIdentity( clients.NewUserAssignedIdentitiesClient(identityClient), - subscriptionID, - integrationTestResourceGroup, + 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := identityWrapper.Scopes()[0] @@ -183,8 +181,7 @@ func TestManagedIdentityUserAssignedIdentityIntegration(t *testing.T) { identityWrapper := manual.NewManagedIdentityUserAssignedIdentity( clients.NewUserAssignedIdentitiesClient(identityClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := identityWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-application-gateway_test.go b/sources/azure/integration-tests/network-application-gateway_test.go index f9ea1571..a06655fa 100644 --- a/sources/azure/integration-tests/network-application-gateway_test.go +++ b/sources/azure/integration-tests/network-application-gateway_test.go @@ -142,8 +142,7 @@ func TestNetworkApplicationGatewayIntegration(t *testing.T) { agWrapper := manual.NewNetworkApplicationGateway( clients.NewApplicationGatewaysClient(agClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := agWrapper.Scopes()[0] @@ -182,8 +181,7 @@ func TestNetworkApplicationGatewayIntegration(t *testing.T) { agWrapper := manual.NewNetworkApplicationGateway( clients.NewApplicationGatewaysClient(agClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := agWrapper.Scopes()[0] @@ -227,8 +225,7 @@ func TestNetworkApplicationGatewayIntegration(t *testing.T) { agWrapper := manual.NewNetworkApplicationGateway( clients.NewApplicationGatewaysClient(agClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := agWrapper.Scopes()[0] @@ -267,8 +264,7 @@ func TestNetworkApplicationGatewayIntegration(t *testing.T) { agWrapper := manual.NewNetworkApplicationGateway( clients.NewApplicationGatewaysClient(agClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := agWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-load-balancer_test.go b/sources/azure/integration-tests/network-load-balancer_test.go index 683ef648..eb3aae5e 100644 --- a/sources/azure/integration-tests/network-load-balancer_test.go +++ b/sources/azure/integration-tests/network-load-balancer_test.go @@ -125,8 +125,7 @@ func TestNetworkLoadBalancerIntegration(t *testing.T) { lbWrapper := manual.NewNetworkLoadBalancer( clients.NewLoadBalancersClient(lbClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := lbWrapper.Scopes()[0] @@ -165,8 +164,7 @@ func TestNetworkLoadBalancerIntegration(t *testing.T) { lbWrapper := manual.NewNetworkLoadBalancer( clients.NewLoadBalancersClient(lbClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := lbWrapper.Scopes()[0] @@ -221,8 +219,7 @@ func TestNetworkLoadBalancerIntegration(t *testing.T) { lbWrapper := 
manual.NewNetworkLoadBalancer( clients.NewLoadBalancersClient(lbClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := lbWrapper.Scopes()[0] @@ -256,8 +253,7 @@ func TestNetworkLoadBalancerIntegration(t *testing.T) { lbWrapper := manual.NewNetworkLoadBalancer( clients.NewLoadBalancersClient(lbClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := lbWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-network-interface_test.go b/sources/azure/integration-tests/network-network-interface_test.go index e9ec529a..828b0049 100644 --- a/sources/azure/integration-tests/network-network-interface_test.go +++ b/sources/azure/integration-tests/network-network-interface_test.go @@ -100,8 +100,7 @@ func TestNetworkNetworkInterfaceIntegration(t *testing.T) { nicWrapper := manual.NewNetworkNetworkInterface( clients.NewNetworkInterfacesClient(nicClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := nicWrapper.Scopes()[0] @@ -136,8 +135,7 @@ func TestNetworkNetworkInterfaceIntegration(t *testing.T) { nicWrapper := manual.NewNetworkNetworkInterface( clients.NewNetworkInterfacesClient(nicClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := nicWrapper.Scopes()[0] @@ -178,8 +176,7 @@ func TestNetworkNetworkInterfaceIntegration(t *testing.T) { nicWrapper := manual.NewNetworkNetworkInterface( clients.NewNetworkInterfacesClient(nicClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := nicWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-network-security-group_test.go b/sources/azure/integration-tests/network-network-security-group_test.go index a35ed229..a38c54bf 100644 --- a/sources/azure/integration-tests/network-network-security-group_test.go +++ b/sources/azure/integration-tests/network-network-security-group_test.go @@ -82,8 +82,7 @@ func TestNetworkNetworkSecurityGroupIntegration(t *testing.T) { nsgWrapper := manual.NewNetworkNetworkSecurityGroup( clients.NewNetworkSecurityGroupsClient(nsgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := nsgWrapper.Scopes()[0] @@ -118,8 +117,7 @@ func TestNetworkNetworkSecurityGroupIntegration(t *testing.T) { nsgWrapper := manual.NewNetworkNetworkSecurityGroup( clients.NewNetworkSecurityGroupsClient(nsgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := nsgWrapper.Scopes()[0] @@ -163,8 +161,7 @@ func TestNetworkNetworkSecurityGroupIntegration(t *testing.T) { nsgWrapper := manual.NewNetworkNetworkSecurityGroup( clients.NewNetworkSecurityGroupsClient(nsgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, 
integrationTestResourceGroup)}, ) scope := nsgWrapper.Scopes()[0] @@ -205,8 +202,7 @@ func TestNetworkNetworkSecurityGroupIntegration(t *testing.T) { nsgWrapper := manual.NewNetworkNetworkSecurityGroup( clients.NewNetworkSecurityGroupsClient(nsgClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := nsgWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-public-ip-address_test.go b/sources/azure/integration-tests/network-public-ip-address_test.go index e62e9503..e68169eb 100644 --- a/sources/azure/integration-tests/network-public-ip-address_test.go +++ b/sources/azure/integration-tests/network-public-ip-address_test.go @@ -126,8 +126,7 @@ func TestNetworkPublicIPAddressIntegration(t *testing.T) { publicIPWrapper := manual.NewNetworkPublicIPAddress( clients.NewPublicIPAddressesClient(publicIPClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := publicIPWrapper.Scopes()[0] @@ -178,8 +177,7 @@ func TestNetworkPublicIPAddressIntegration(t *testing.T) { publicIPWrapper := manual.NewNetworkPublicIPAddress( clients.NewPublicIPAddressesClient(publicIPClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := publicIPWrapper.Scopes()[0] @@ -223,8 +221,7 @@ func TestNetworkPublicIPAddressIntegration(t *testing.T) { publicIPWrapper := manual.NewNetworkPublicIPAddress( clients.NewPublicIPAddressesClient(publicIPClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := publicIPWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-route-table_test.go b/sources/azure/integration-tests/network-route-table_test.go index da109c92..c2d2a8b9 100644 --- a/sources/azure/integration-tests/network-route-table_test.go +++ b/sources/azure/integration-tests/network-route-table_test.go @@ -100,8 +100,7 @@ func TestNetworkRouteTableIntegration(t *testing.T) { routeTableWrapper := manual.NewNetworkRouteTable( clients.NewRouteTablesClient(routeTableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := routeTableWrapper.Scopes()[0] @@ -136,8 +135,7 @@ func TestNetworkRouteTableIntegration(t *testing.T) { routeTableWrapper := manual.NewNetworkRouteTable( clients.NewRouteTablesClient(routeTableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := routeTableWrapper.Scopes()[0] @@ -181,8 +179,7 @@ func TestNetworkRouteTableIntegration(t *testing.T) { routeTableWrapper := manual.NewNetworkRouteTable( clients.NewRouteTablesClient(routeTableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := routeTableWrapper.Scopes()[0] @@ -223,8 +220,7 @@ func TestNetworkRouteTableIntegration(t *testing.T) { routeTableWrapper := manual.NewNetworkRouteTable( 
clients.NewRouteTablesClient(routeTableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := routeTableWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-virtual-network_test.go b/sources/azure/integration-tests/network-virtual-network_test.go index 6e823e67..2c0e0e43 100644 --- a/sources/azure/integration-tests/network-virtual-network_test.go +++ b/sources/azure/integration-tests/network-virtual-network_test.go @@ -66,8 +66,7 @@ func TestNetworkVirtualNetworkIntegration(t *testing.T) { vnetWrapper := manual.NewNetworkVirtualNetwork( clients.NewVirtualNetworksClient(vnetClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vnetWrapper.Scopes()[0] @@ -102,8 +101,7 @@ func TestNetworkVirtualNetworkIntegration(t *testing.T) { vnetWrapper := manual.NewNetworkVirtualNetwork( clients.NewVirtualNetworksClient(vnetClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vnetWrapper.Scopes()[0] @@ -144,8 +142,7 @@ func TestNetworkVirtualNetworkIntegration(t *testing.T) { vnetWrapper := manual.NewNetworkVirtualNetwork( clients.NewVirtualNetworksClient(vnetClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := vnetWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/network-zone_test.go b/sources/azure/integration-tests/network-zone_test.go index 237c1415..ec54ad17 100644 --- a/sources/azure/integration-tests/network-zone_test.go +++ b/sources/azure/integration-tests/network-zone_test.go @@ -85,8 +85,7 @@ func TestNetworkZoneIntegration(t *testing.T) { zoneWrapper := manual.NewNetworkZone( clients.NewZonesClient(zonesClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := zoneWrapper.Scopes()[0] @@ -136,8 +135,7 @@ func TestNetworkZoneIntegration(t *testing.T) { zoneWrapper := manual.NewNetworkZone( clients.NewZonesClient(zonesClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := zoneWrapper.Scopes()[0] @@ -183,8 +181,7 @@ func TestNetworkZoneIntegration(t *testing.T) { zoneWrapper := manual.NewNetworkZone( clients.NewZonesClient(zonesClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := zoneWrapper.Scopes()[0] @@ -225,8 +222,7 @@ func TestNetworkZoneIntegration(t *testing.T) { zoneWrapper := manual.NewNetworkZone( clients.NewZonesClient(zonesClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := zoneWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/sql-database_test.go b/sources/azure/integration-tests/sql-database_test.go index ae959a0a..2646d385 100644 --- a/sources/azure/integration-tests/sql-database_test.go 
+++ b/sources/azure/integration-tests/sql-database_test.go @@ -106,8 +106,7 @@ func TestSQLDatabaseIntegration(t *testing.T) { sqlDbWrapper := manual.NewSqlDatabase( clients.NewSqlDatabasesClient(sqlDatabaseClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := sqlDbWrapper.Scopes()[0] @@ -160,8 +159,7 @@ func TestSQLDatabaseIntegration(t *testing.T) { sqlDbWrapper := manual.NewSqlDatabase( clients.NewSqlDatabasesClient(sqlDatabaseClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := sqlDbWrapper.Scopes()[0] @@ -208,8 +206,7 @@ func TestSQLDatabaseIntegration(t *testing.T) { sqlDbWrapper := manual.NewSqlDatabase( clients.NewSqlDatabasesClient(sqlDatabaseClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := sqlDbWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/sql-server_test.go b/sources/azure/integration-tests/sql-server_test.go index d7004a37..76a625cb 100644 --- a/sources/azure/integration-tests/sql-server_test.go +++ b/sources/azure/integration-tests/sql-server_test.go @@ -75,8 +75,7 @@ func TestSQLServerIntegration(t *testing.T) { sqlServerWrapper := manual.NewSqlServer( clients.NewSqlServersClient(sqlServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := sqlServerWrapper.Scopes()[0] @@ -126,8 +125,7 @@ func TestSQLServerIntegration(t *testing.T) { sqlServerWrapper := manual.NewSqlServer( clients.NewSqlServersClient(sqlServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := sqlServerWrapper.Scopes()[0] @@ -173,8 +171,7 @@ func TestSQLServerIntegration(t *testing.T) { sqlServerWrapper := manual.NewSqlServer( clients.NewSqlServersClient(sqlServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := sqlServerWrapper.Scopes()[0] @@ -268,8 +265,7 @@ func TestSQLServerIntegration(t *testing.T) { sqlServerWrapper := manual.NewSqlServer( clients.NewSqlServersClient(sqlServerClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := sqlServerWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/storage-account_test.go b/sources/azure/integration-tests/storage-account_test.go index 17872422..eb04c4cb 100644 --- a/sources/azure/integration-tests/storage-account_test.go +++ b/sources/azure/integration-tests/storage-account_test.go @@ -77,8 +77,7 @@ func TestStorageAccountIntegration(t *testing.T) { saWrapper := manual.NewStorageAccount( clients.NewStorageAccountsClient(saClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := saWrapper.Scopes()[0] @@ -116,8 +115,7 @@ func TestStorageAccountIntegration(t *testing.T) { 
saWrapper := manual.NewStorageAccount( clients.NewStorageAccountsClient(saClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := saWrapper.Scopes()[0] @@ -164,8 +162,7 @@ func TestStorageAccountIntegration(t *testing.T) { saWrapper := manual.NewStorageAccount( clients.NewStorageAccountsClient(saClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := saWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/storage-blob-container_test.go b/sources/azure/integration-tests/storage-blob-container_test.go index 5846d95c..8558917b 100644 --- a/sources/azure/integration-tests/storage-blob-container_test.go +++ b/sources/azure/integration-tests/storage-blob-container_test.go @@ -99,8 +99,7 @@ func TestStorageBlobContainerIntegration(t *testing.T) { bcWrapper := manual.NewStorageBlobContainer( clients.NewBlobContainersClient(bcClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := bcWrapper.Scopes()[0] @@ -136,8 +135,7 @@ func TestStorageBlobContainerIntegration(t *testing.T) { bcWrapper := manual.NewStorageBlobContainer( clients.NewBlobContainersClient(bcClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := bcWrapper.Scopes()[0] @@ -181,8 +179,7 @@ func TestStorageBlobContainerIntegration(t *testing.T) { bcWrapper := manual.NewStorageBlobContainer( clients.NewBlobContainersClient(bcClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := bcWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/storage-fileshare_test.go b/sources/azure/integration-tests/storage-fileshare_test.go index cf1c5453..0d2ccd6a 100644 --- a/sources/azure/integration-tests/storage-fileshare_test.go +++ b/sources/azure/integration-tests/storage-fileshare_test.go @@ -95,8 +95,7 @@ func TestStorageFileShareIntegration(t *testing.T) { fsWrapper := manual.NewStorageFileShare( clients.NewFileSharesClient(fsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := fsWrapper.Scopes()[0] @@ -132,8 +131,7 @@ func TestStorageFileShareIntegration(t *testing.T) { fsWrapper := manual.NewStorageFileShare( clients.NewFileSharesClient(fsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := fsWrapper.Scopes()[0] @@ -177,8 +175,7 @@ func TestStorageFileShareIntegration(t *testing.T) { fsWrapper := manual.NewStorageFileShare( clients.NewFileSharesClient(fsClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := fsWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/storage-queues_test.go b/sources/azure/integration-tests/storage-queues_test.go index bcc75c8b..0275ae33 100644 --- 
a/sources/azure/integration-tests/storage-queues_test.go +++ b/sources/azure/integration-tests/storage-queues_test.go @@ -94,8 +94,7 @@ func TestStorageQueuesIntegration(t *testing.T) { queueWrapper := manual.NewStorageQueues( clients.NewQueuesClient(queueClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := queueWrapper.Scopes()[0] @@ -132,8 +131,7 @@ func TestStorageQueuesIntegration(t *testing.T) { queueWrapper := manual.NewStorageQueues( clients.NewQueuesClient(queueClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := queueWrapper.Scopes()[0] @@ -178,8 +176,7 @@ func TestStorageQueuesIntegration(t *testing.T) { queueWrapper := manual.NewStorageQueues( clients.NewQueuesClient(queueClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := queueWrapper.Scopes()[0] @@ -221,8 +218,7 @@ func TestStorageQueuesIntegration(t *testing.T) { queueWrapper := manual.NewStorageQueues( clients.NewQueuesClient(queueClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := queueWrapper.Scopes()[0] diff --git a/sources/azure/integration-tests/storage-table_test.go b/sources/azure/integration-tests/storage-table_test.go index c5f551ec..51e7954b 100644 --- a/sources/azure/integration-tests/storage-table_test.go +++ b/sources/azure/integration-tests/storage-table_test.go @@ -94,8 +94,7 @@ func TestStorageTableIntegration(t *testing.T) { tableWrapper := manual.NewStorageTable( clients.NewTablesClient(tableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := tableWrapper.Scopes()[0] @@ -132,8 +131,7 @@ func TestStorageTableIntegration(t *testing.T) { tableWrapper := manual.NewStorageTable( clients.NewTablesClient(tableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := tableWrapper.Scopes()[0] @@ -178,8 +176,7 @@ func TestStorageTableIntegration(t *testing.T) { tableWrapper := manual.NewStorageTable( clients.NewTablesClient(tableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := tableWrapper.Scopes()[0] @@ -221,8 +218,7 @@ func TestStorageTableIntegration(t *testing.T) { tableWrapper := manual.NewStorageTable( clients.NewTablesClient(tableClient), - subscriptionID, - integrationTestResourceGroup, + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, ) scope := tableWrapper.Scopes()[0] diff --git a/sources/azure/manual/.cursor/rules/azure-manual-adapter-creation.mdc b/sources/azure/manual/.cursor/rules/azure-manual-adapter-creation.mdc index 1f1f95cd..6cb89c24 100644 --- a/sources/azure/manual/.cursor/rules/azure-manual-adapter-creation.mdc +++ b/sources/azure/manual/.cursor/rules/azure-manual-adapter-creation.mdc @@ -94,22 +94,21 @@ 
Use when the Azure API supports both listing and searching: Choose the appropriate base struct based on Azure resource scope: -### `ResourceGroupBase` - Resource Group Scoped Resources +### `MultiResourceGroupBase` - Resource Group Scoped Resources (multi-scope) -For resources scoped to a specific resource group: +Resource-group-scoped adapters use **one adapter per resource type** that holds a slice of resource-group scopes. The engine calls List/Get once per scope; scope is resolved in each method via `ResourceGroupScopeFromScope(scope)`. ```go type computeVirtualMachineWrapper struct { client clients.VirtualMachinesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewComputeVirtualMachine(client clients.VirtualMachinesClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewComputeVirtualMachine(client clients.VirtualMachinesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &computeVirtualMachineWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION, azureshared.ComputeVirtualMachine, ), @@ -117,6 +116,16 @@ func NewComputeVirtualMachine(client clients.VirtualMachinesClient, subscription } ``` +In Get/List/ListStream/Search, resolve scope to resource group (and subscription when needed) with: + +```go +rgScope, err := c.ResourceGroupScopeFromScope(scope) +if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) +} +// use rgScope.ResourceGroup and rgScope.SubscriptionID +``` + **Examples:** Compute Virtual Machines, Compute Disks, Network Interfaces, Network Security Groups ### `SubscriptionBase` - Subscription-Level Resources @@ -448,14 +457,14 @@ if err != nil { ### Pager Pattern -Handle Azure SDK pagers consistently: +Handle Azure SDK pagers consistently. For multi-scope resource group adapters, resolve scope first: ```go -resourceGroup := azureshared.ResourceGroupFromScope(scope) -if resourceGroup == "" { - resourceGroup = c.ResourceGroup() +rgScope, err := c.ResourceGroupScopeFromScope(scope) +if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } -pager := c.client.NewListPager(resourceGroup, nil) +pager := c.client.NewListPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -501,7 +510,7 @@ func TestComputeVirtualMachine(t *testing.T) { resourceGroup := "test-resource-group" t.Run("Get", func(t *testing.T) { - wrapper := manual.NewComputeVirtualMachine(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachine(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) vm := createAzureVirtualMachine("test-vm", "Succeeded") mockClient.EXPECT().Get(ctx, resourceGroup, "test-vm", nil).Return(armcompute.VirtualMachinesClientGetResponse{ VirtualMachine: *vm, @@ -708,26 +717,33 @@ if resourceID != nil && *resourceID != "" { ### 8. Adapter Registration -Register adapters in the main adapters file with proper scoping: +Resource-group-scoped adapters are registered once per type with a slice of all resource group scopes. 
Build `resourceGroupScopes` after discovering resource groups, then pass it to each constructor: ```go -// Register resource group scoped adapters -for _, resourceGroup := range resourceGroups { +// Build resource group scopes from discovered resource groups +resourceGroupScopes := make([]azureshared.ResourceGroupScope, 0, len(resourceGroups)) +for _, rg := range resourceGroups { + resourceGroupScopes = append(resourceGroupScopes, azureshared.NewResourceGroupScope(subscriptionID, rg)) +} + +// Multi-scope resource group adapters (one adapter per type handling all resource groups) +if len(resourceGroupScopes) > 0 { adapters = append(adapters, sources.WrapperToAdapter(NewComputeVirtualMachine( clients.NewVirtualMachinesClient(vmClient), - subscriptionID, - resourceGroup, - )), + resourceGroupScopes, + ), cache), + sources.WrapperToAdapter(NewStorageAccount(..., resourceGroupScopes), cache), + // ... one line per resource type (33 total) ) } -// Register subscription-level adapters +// Subscription-level adapters are registered separately with subscriptionID adapters = append(adapters, sources.WrapperToAdapter(NewSubscription( clients.NewSubscriptionClient(subClient), subscriptionID, - )), + ), cache), ) ``` @@ -759,7 +775,7 @@ Before submitting a new adapter, ensure: - [ ] File follows naming convention (`{resource-name}.go`) - [ ] Imports are properly organized and minimal - [ ] Wrapper type matches Azure API capabilities -- [ ] Base struct matches resource scope (ResourceGroup/Subscription) +- [ ] Base struct matches resource scope (MultiResourceGroupBase/Subscription) - [ ] All required methods implemented (IAMPermissions, PredefinedRole, PotentialLinks) - [ ] Terraform mappings are correct and include documentation URLs - [ ] Get/Search lookups match the resource's query parameters diff --git a/sources/azure/manual/adapters.go b/sources/azure/manual/adapters.go index 2734a4ca..1b1a0697 100644 --- a/sources/azure/manual/adapters.go +++ b/sources/azure/manual/adapters.go @@ -9,6 +9,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/batch/armbatch/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v7" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cosmos/armcosmos" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault/v2" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v8" @@ -22,6 +23,7 @@ import ( "github.com/overmindtech/cli/sdpcache" "github.com/overmindtech/cli/sources" "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" ) // Adapters returns a slice of discovery.Adapter instances for Azure Source. 
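The definitions of `ResourceGroupScope`, `NewResourceGroupScope`, `MultiResourceGroupBase` and `ResourceGroupScopeFromScope` are not shown above; the sketch below only reconstructs the shape implied by their call sites in this diff. Every field, signature and especially the scope-string encoding here is an inferred assumption, not the real `sources/azure/shared` code, and the category/item-type plumbing of the real base is omitted.

```go
// Inferred sketch of the azureshared scope helpers used throughout this diff.
// Field names, signatures and the scope-string encoding are assumptions based
// on the call sites above, not the real sources/azure/shared implementation.
package azureshared

import "fmt"

// ResourceGroupScope pairs a subscription with one of its resource groups.
type ResourceGroupScope struct {
	SubscriptionID string
	ResourceGroup  string
}

func NewResourceGroupScope(subscriptionID, resourceGroup string) ResourceGroupScope {
	return ResourceGroupScope{SubscriptionID: subscriptionID, ResourceGroup: resourceGroup}
}

// scopeString encodes a scope as a string; the real format is not visible in
// this diff, so this is a placeholder encoding.
func (s ResourceGroupScope) scopeString() string {
	return fmt.Sprintf("%s.%s", s.SubscriptionID, s.ResourceGroup)
}

// MultiResourceGroupBase lets one wrapper serve every configured resource
// group; adapter category and item-type plumbing is omitted in this sketch.
type MultiResourceGroupBase struct {
	scopes []ResourceGroupScope
}

// Scopes returns one scope string per configured resource group, which is why
// the integration tests above can take wrapper.Scopes()[0].
func (b *MultiResourceGroupBase) Scopes() []string {
	out := make([]string, 0, len(b.scopes))
	for _, s := range b.scopes {
		out = append(out, s.scopeString())
	}
	return out
}

// ResourceGroupScopeFromScope resolves a scope string back to its
// subscription/resource-group pair, erroring on scopes this adapter was not
// configured with.
func (b *MultiResourceGroupBase) ResourceGroupScopeFromScope(scope string) (ResourceGroupScope, error) {
	for _, s := range b.scopes {
		if s.scopeString() == scope {
			return s, nil
		}
	}
	return ResourceGroupScope{}, fmt.Errorf("scope %q is not served by this adapter", scope)
}
```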
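The rules file above also states that a multi-scope wrapper advertises one scope per configured resource group and that the engine calls List/Get once per scope, but that caller side never appears in this patch. The following is a minimal sketch of that dispatch, assuming only the `Scopes()` and `List(ctx, scope)` shapes visible in the diff; `listAllScopes` and the sdp import path are illustrative assumptions, not code from this repository.

```go
// Minimal sketch (not part of this patch): drive one multi-scope wrapper
// across every resource-group scope it was constructed with, mirroring the
// "List/Get once per scope" behaviour described in the rules file above.
package main

import (
	"context"
	"fmt"

	sdp "github.com/overmindtech/cli/sdp-go" // import path assumed, adjust to the sdp module this repo uses
)

// multiScopeLister is the subset of the wrapper contract used by this sketch;
// the method shapes match the wrapper signatures visible in the diff.
type multiScopeLister interface {
	Scopes() []string
	List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError)
}

// listAllScopes is a hypothetical helper, not the real engine: it calls List
// once per advertised scope and aggregates the results.
func listAllScopes(ctx context.Context, w multiScopeLister) []*sdp.Item {
	var all []*sdp.Item
	for _, scope := range w.Scopes() {
		items, qErr := w.List(ctx, scope)
		if qErr != nil {
			fmt.Printf("scope %s: %v\n", scope, qErr)
			continue
		}
		all = append(all, items...)
	}
	return all
}
```

Inside the wrapper, `ResourceGroupScopeFromScope(scope)` maps each scope string back to its subscription/resource-group pair, which is why the per-method resolution shown in the rules file replaces the old single-value `ResourceGroup()` fallback.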
@@ -67,6 +69,12 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred "ovm.source.resource_group_count": len(resourceGroups), }).Info("Discovered resource groups") + // Build resource group scopes for multi-scope adapters + resourceGroupScopes := make([]azureshared.ResourceGroupScope, 0, len(resourceGroups)) + for _, rg := range resourceGroups { + resourceGroupScopes = append(resourceGroupScopes, azureshared.NewResourceGroupScope(subscriptionID, rg)) + } + // Initialize Azure SDK clients vmClient, err := armcompute.NewVirtualMachinesClient(subscriptionID, cred, nil) if err != nil { @@ -226,278 +234,145 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create proximity placement groups client: %w", err) } - // Create adapters for each resource group - for _, resourceGroup := range resourceGroups { - // Add Compute Virtual Machine adapter for this resource group + zonesClient, err := armdns.NewZonesClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create zones client: %w", err) + } + + // Multi-scope resource group adapters (one adapter per type handling all resource groups) + if len(resourceGroupScopes) > 0 { adapters = append(adapters, sources.WrapperToAdapter(NewComputeVirtualMachine( clients.NewVirtualMachinesClient(vmClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Storage Account adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewStorageAccount( clients.NewStorageAccountsClient(storageAccountsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Storage Blob Container adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewStorageBlobContainer( clients.NewBlobContainersClient(blobContainersClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Storage File Share adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewStorageFileShare( clients.NewFileSharesClient(fileSharesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Storage Queue adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewStorageQueues( clients.NewQueuesClient(queuesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Storage Table adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewStorageTable( clients.NewTablesClient(tablesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Network Virtual Network adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewNetworkVirtualNetwork( clients.NewVirtualNetworksClient(virtualNetworksClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Network Network Interface adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewNetworkNetworkInterface( clients.NewNetworkInterfacesClient(networkInterfacesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add SQL Database adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewSqlDatabase( clients.NewSqlDatabasesClient(sqlDatabasesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // 
Add DocumentDB Database Account adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewDocumentDBDatabaseAccounts( clients.NewDocumentDBDatabaseAccountsClient(documentDBDatabaseAccountsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Key Vault Vault adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewKeyVaultVault( clients.NewVaultsClient(keyVaultsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Key Vault Managed HSM adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewKeyVaultManagedHSM( clients.NewManagedHSMsClient(managedHSMsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add PostgreSQL Database adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewDBforPostgreSQLDatabase( clients.NewPostgreSQLDatabasesClient(postgreSQLDatabasesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Network Public IP Address adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewNetworkPublicIPAddress( clients.NewPublicIPAddressesClient(publicIPAddressesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Network Load Balancer adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewNetworkLoadBalancer( clients.NewLoadBalancersClient(loadBalancersClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, + ), cache), + sources.WrapperToAdapter(NewNetworkZone( + clients.NewZonesClient(zonesClient), + resourceGroupScopes, ), cache), - ) - // Add Batch Account adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewBatchAccount( clients.NewBatchAccountsClient(batchAccountsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Virtual Machine Scale Set adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeVirtualMachineScaleSet( clients.NewVirtualMachineScaleSetsClient(virtualMachineScaleSetsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Availability Set adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeAvailabilitySet( clients.NewAvailabilitySetsClient(availabilitySetsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Disk adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeDisk( clients.NewDisksClient(disksClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Network Security Group adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewNetworkNetworkSecurityGroup( clients.NewNetworkSecurityGroupsClient(networkSecurityGroupsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Network Route Table adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewNetworkRouteTable( clients.NewRouteTablesClient(routeTablesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Network Application Gateway adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewNetworkApplicationGateway( 
clients.NewApplicationGatewaysClient(applicationGatewaysClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add SQL Server adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewSqlServer( clients.NewSqlServersClient(sqlServersClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add PostgreSQL Flexible Server adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServer( clients.NewPostgreSQLFlexibleServersClient(postgresqlFlexibleServersClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add Key Vault Secret adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewKeyVaultSecret( clients.NewSecretsClient(secretsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - - // Add User Assigned Identity adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewManagedIdentityUserAssignedIdentity( clients.NewUserAssignedIdentitiesClient(userAssignedIdentitiesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Role Assignment adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewAuthorizationRoleAssignment( clients.NewRoleAssignmentsClient(roleAssignmentsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Disk Encryption Set adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeDiskEncryptionSet( clients.NewDiskEncryptionSetsClient(diskEncryptionSetsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Image adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeImage( clients.NewImagesClient(imagesClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Virtual Machine Run Command adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeVirtualMachineRunCommand( clients.NewVirtualMachineRunCommandsClient(virtualMachineRunCommandsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Virtual Machine Extension adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeVirtualMachineExtension( clients.NewVirtualMachineExtensionsClient(virtualMachineExtensionsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), - ) - // Add Proximity Placement Group adapter for this resource group - adapters = append(adapters, sources.WrapperToAdapter(NewComputeProximityPlacementGroup( clients.NewProximityPlacementGroupsClient(proximityPlacementGroupsClient), - subscriptionID, - resourceGroup, + resourceGroupScopes, ), cache), ) } @@ -510,168 +385,43 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred } else { // For metadata registration only - no actual clients needed // This is used to enumerate available adapter types for documentation - // Create placeholder adapters with nil clients for metadata registration + // Create placeholder adapters with nil clients and one placeholder scope + placeholderResourceGroupScopes := []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, "placeholder-resource-group")} + noOpCache := sdpcache.NewNoOpCache() adapters = 
append(adapters, - sources.WrapperToAdapter(NewComputeVirtualMachine( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewStorageAccount( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewStorageBlobContainer( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewStorageFileShare( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewStorageQueues( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewStorageTable( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewNetworkVirtualNetwork( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewNetworkNetworkInterface( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewSqlDatabase( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewDocumentDBDatabaseAccounts( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewKeyVaultVault( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewDBforPostgreSQLDatabase( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewNetworkPublicIPAddress( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewNetworkLoadBalancer( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewBatchAccount( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewComputeVirtualMachineScaleSet( - nil, // nil client is okay for metadata registration - subscriptionID, - 
"placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewComputeAvailabilitySet( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewComputeDisk( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewNetworkNetworkSecurityGroup( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewNetworkRouteTable( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewNetworkApplicationGateway( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewKeyVaultManagedHSM( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewSqlServer( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServer( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewKeyVaultSecret( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewManagedIdentityUserAssignedIdentity( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewAuthorizationRoleAssignment( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewComputeDiskEncryptionSet( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewComputeImage( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewComputeVirtualMachineRunCommand( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration - sources.WrapperToAdapter(NewComputeVirtualMachineExtension( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata 
registration - sources.WrapperToAdapter(NewComputeProximityPlacementGroup( - nil, // nil client is okay for metadata registration - subscriptionID, - "placeholder-resource-group", - ), sdpcache.NewNoOpCache()), // no-op cache for metadata registration + sources.WrapperToAdapter(NewComputeVirtualMachine(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewStorageAccount(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewStorageBlobContainer(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewStorageFileShare(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewStorageQueues(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewStorageTable(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkVirtualNetwork(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkNetworkInterface(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewSqlDatabase(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewDocumentDBDatabaseAccounts(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewKeyVaultVault(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewKeyVaultManagedHSM(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewDBforPostgreSQLDatabase(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkPublicIPAddress(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkLoadBalancer(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkZone(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewBatchAccount(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeVirtualMachineScaleSet(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeAvailabilitySet(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeDisk(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkNetworkSecurityGroup(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkRouteTable(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkApplicationGateway(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewSqlServer(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServer(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewKeyVaultSecret(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewManagedIdentityUserAssignedIdentity(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewAuthorizationRoleAssignment(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeDiskEncryptionSet(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeImage(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeVirtualMachineRunCommand(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeVirtualMachineExtension(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewComputeProximityPlacementGroup(nil, 
placeholderResourceGroupScopes), noOpCache), ) _ = regions diff --git a/sources/azure/manual/authorization-role-assignment.go b/sources/azure/manual/authorization-role-assignment.go index 13d4b734..66f6f6eb 100644 --- a/sources/azure/manual/authorization-role-assignment.go +++ b/sources/azure/manual/authorization-role-assignment.go @@ -11,6 +11,8 @@ import ( "github.com/overmindtech/cli/sources/azure/clients" azureshared "github.com/overmindtech/cli/sources/azure/shared" "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sdpcache" + "github.com/overmindtech/cli/discovery" ) var AuthorizationRoleAssignmentLookupByName = shared.NewItemTypeLookup("name", azureshared.AuthorizationRoleAssignment) @@ -18,15 +20,14 @@ var AuthorizationRoleAssignmentLookupByName = shared.NewItemTypeLookup("name", a type authorizationRoleAssignmentWrapper struct { client clients.RoleAssignmentsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewAuthorizationRoleAssignment(client clients.RoleAssignmentsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewAuthorizationRoleAssignment(client clients.RoleAssignmentsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &authorizationRoleAssignmentWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY, azureshared.AuthorizationRoleAssignment, ), @@ -37,11 +38,11 @@ func (a authorizationRoleAssignmentWrapper) List(ctx context.Context, scope stri if scope == "" { return nil, azureshared.QueryError(errors.New("scope cannot be empty"), scope, a.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = a.ResourceGroup() + rgScope, err := a.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, a.Type()) } - pager := a.client.ListForResourceGroup(resourceGroup, nil) + pager := a.client.ListForResourceGroup(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -61,6 +62,32 @@ func (a authorizationRoleAssignmentWrapper) List(ctx context.Context, scope stri return items, nil } + +func (a authorizationRoleAssignmentWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := a.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, a.Type())) + return + } + pager := a.client.ListForResourceGroup(rgScope.ResourceGroup, nil) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, a.Type())) + return + } + for _, roleAssignment := range page.Value { + item, sdpErr := a.azureRoleAssignmentToSDPItem(roleAssignment, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} func (a authorizationRoleAssignmentWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { if scope == "" { return nil, azureshared.QueryError(errors.New("scope cannot be empty"), scope, a.Type()) @@ -74,8 +101,12 @@ func (a authorizationRoleAssignmentWrapper) Get(ctx context.Context, scope strin 
return nil, azureshared.QueryError(errors.New("roleAssignmentName cannot be empty"), scope, a.Type()) } + rgScope, err := a.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, a.Type()) + } // Construct the Azure scope path from either subscription ID or resource group name - azureScope := azureshared.ConstructRoleAssignmentScope(scope, a.SubscriptionID()) + azureScope := azureshared.ConstructRoleAssignmentScope(scope, rgScope.SubscriptionID) if azureScope == "" { return nil, azureshared.QueryError(errors.New("failed to construct Azure scope path"), scope, a.Type()) } @@ -104,11 +135,11 @@ func (a authorizationRoleAssignmentWrapper) azureRoleAssignmentToSDPItem(roleAss return nil, azureshared.QueryError(errors.New("role assignment name cannot be empty"), scope, a.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = a.ResourceGroup() + rgScope, err := a.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, a.Type()) } - err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(resourceGroup, roleAssignmentName)) + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(rgScope.ResourceGroup, roleAssignmentName)) if err != nil { return nil, azureshared.QueryError(err, scope, a.Type()) } diff --git a/sources/azure/manual/authorization-role-assignment_test.go b/sources/azure/manual/authorization-role-assignment_test.go index 3b3e4ae0..1635cdbe 100644 --- a/sources/azure/manual/authorization-role-assignment_test.go +++ b/sources/azure/manual/authorization-role-assignment_test.go @@ -40,7 +40,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { RoleAssignment: *roleAssignment, }, nil) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, scope, roleAssignmentName, true) @@ -88,7 +88,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("Get_EmptyScope", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, "", "test-role-assignment", true) @@ -100,7 +100,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (empty) @@ -119,7 +119,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("Get_EmptyRoleAssignmentName", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, 
resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, scope, "", true) @@ -138,7 +138,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { armauthorization.RoleAssignmentsClientGetResponse{}, expectedError) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, scope, roleAssignmentName, true) @@ -163,7 +163,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { RoleAssignment: *roleAssignment, }, nil) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, scope, roleAssignmentName, true) @@ -193,7 +193,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { mockClient.EXPECT().ListForResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -230,7 +230,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("List_EmptyScope", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -260,7 +260,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { mockClient.EXPECT().ListForResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -300,7 +300,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { mockClient.EXPECT().ListForResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -316,7 +316,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { 
t.Run("GetLookups", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) lookups := wrapper.GetLookups() if len(lookups) != 1 { @@ -337,7 +337,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("TerraformMappings", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) mappings := wrapper.TerraformMappings() if len(mappings) == 0 { @@ -362,7 +362,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) potentialLinks := wrapper.PotentialLinks() if len(potentialLinks) == 0 { @@ -378,7 +378,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("IAMPermissions", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) permissions := wrapper.IAMPermissions() if len(permissions) != 1 { @@ -393,7 +393,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { t.Run("PredefinedRole", func(t *testing.T) { mockClient := mocks.NewMockRoleAssignmentsClient(ctrl) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Use interface assertion to access PredefinedRole method if roleInterface, ok := interface{}(wrapper).(interface{ PredefinedRole() string }); ok { @@ -420,7 +420,7 @@ func TestAuthorizationRoleAssignment(t *testing.T) { RoleAssignment: *roleAssignment, }, nil) - wrapper := manual.NewAuthorizationRoleAssignment(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewAuthorizationRoleAssignment(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, scope, roleAssignmentName, true) diff --git a/sources/azure/manual/batch-batch-accounts.go b/sources/azure/manual/batch-batch-accounts.go index 22461fe0..cb78de1b 100644 --- a/sources/azure/manual/batch-batch-accounts.go +++ b/sources/azure/manual/batch-batch-accounts.go @@ -11,6 +11,8 @@ import ( azureshared "github.com/overmindtech/cli/sources/azure/shared" "github.com/overmindtech/cli/sources/shared" "github.com/overmindtech/cli/sources/stdlib" + "github.com/overmindtech/cli/sdpcache" + "github.com/overmindtech/cli/discovery" ) var BatchAccountLookupByName = 
shared.NewItemTypeLookup("name", azureshared.BatchBatchAccount) @@ -18,15 +20,14 @@ var BatchAccountLookupByName = shared.NewItemTypeLookup("name", azureshared.Batc type batchAccountWrapper struct { client clients.BatchAccountsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewBatchAccount(client clients.BatchAccountsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewBatchAccount(client clients.BatchAccountsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &batchAccountWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION, azureshared.BatchBatchAccount, ), @@ -34,11 +35,11 @@ func NewBatchAccount(client clients.BatchAccountsClient, subscriptionID, resourc } func (b batchAccountWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = b.ResourceGroup() + rgScope, err := b.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) } - pager := b.client.ListByResourceGroup(ctx, resourceGroup) + pager := b.client.ListByResourceGroup(ctx, rgScope.ResourceGroup) var items []*sdp.Item for pager.More() { @@ -63,6 +64,33 @@ func (b batchAccountWrapper) List(ctx context.Context, scope string) ([]*sdp.Ite return items, nil } +func (b batchAccountWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := b.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, b.Type())) + return + } + pager := b.client.ListByResourceGroup(ctx, rgScope.ResourceGroup) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, b.Type())) + return + } + for _, account := range page.Value { + if account.Name == nil { + continue + } + item, sdpErr := b.azureBatchAccountToSDPItem(account, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} func (b batchAccountWrapper) azureBatchAccountToSDPItem(account *armbatch.Account, scope string) (*sdp.Item, *sdp.QueryError) { if account.Name == nil { return nil, azureshared.QueryError(errors.New("name is nil"), scope, b.Type()) @@ -449,11 +477,11 @@ func (b batchAccountWrapper) Get(ctx context.Context, scope string, queryParts . 
return nil, azureshared.QueryError(errors.New("accountName is empty"), scope, b.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = b.ResourceGroup() + rgScope, err := b.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) } - resp, err := b.client.Get(ctx, resourceGroup, accountName) + resp, err := b.client.Get(ctx, rgScope.ResourceGroup, accountName) if err != nil { return nil, azureshared.QueryError(err, scope, b.Type()) } diff --git a/sources/azure/manual/batch-batch-accounts_test.go b/sources/azure/manual/batch-batch-accounts_test.go index ba7c06dd..bd9d37a8 100644 --- a/sources/azure/manual/batch-batch-accounts_test.go +++ b/sources/azure/manual/batch-batch-accounts_test.go @@ -55,7 +55,7 @@ func TestBatchAccount(t *testing.T) { Account: *account, }, nil) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) @@ -211,7 +211,7 @@ func TestBatchAccount(t *testing.T) { t.Run("Get_EmptyAccountName", func(t *testing.T) { mockClient := mocks.NewMockBatchAccountsClient(ctrl) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) @@ -223,7 +223,7 @@ func TestBatchAccount(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockBatchAccountsClient(ctrl) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with no query parts @@ -241,7 +241,7 @@ func TestBatchAccount(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, accountName).Return( armbatch.AccountClientGetResponse{}, expectedErr) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) @@ -267,7 +267,7 @@ func TestBatchAccount(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -313,7 +313,7 @@ func TestBatchAccount(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := 
manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -344,7 +344,7 @@ func TestBatchAccount(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -360,7 +360,7 @@ func TestBatchAccount(t *testing.T) { t.Run("GetLookups", func(t *testing.T) { mockClient := mocks.NewMockBatchAccountsClient(ctrl) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) lookups := wrapper.GetLookups() if len(lookups) != 1 { @@ -374,7 +374,7 @@ func TestBatchAccount(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockBatchAccountsClient(ctrl) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) potentialLinks := wrapper.PotentialLinks() expectedLinks := []shared.ItemType{ @@ -399,7 +399,7 @@ func TestBatchAccount(t *testing.T) { t.Run("TerraformMappings", func(t *testing.T) { mockClient := mocks.NewMockBatchAccountsClient(ctrl) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) mappings := wrapper.TerraformMappings() if len(mappings) != 1 { @@ -417,7 +417,7 @@ func TestBatchAccount(t *testing.T) { t.Run("IAMPermissions", func(t *testing.T) { mockClient := mocks.NewMockBatchAccountsClient(ctrl) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) permissions := wrapper.IAMPermissions() expectedPermissions := []string{ @@ -437,7 +437,7 @@ func TestBatchAccount(t *testing.T) { t.Run("PredefinedRole", func(t *testing.T) { mockClient := mocks.NewMockBatchAccountsClient(ctrl) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // PredefinedRole is available on the wrapper, not the adapter role := wrapper.(interface{ PredefinedRole() string }).PredefinedRole() @@ -466,7 +466,7 @@ func TestBatchAccount(t *testing.T) { Account: *account, }, nil) - wrapper := manual.NewBatchAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewBatchAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) diff --git 
a/sources/azure/manual/compute-availability-set.go b/sources/azure/manual/compute-availability-set.go index c52aa6dd..8fed33fb 100644 --- a/sources/azure/manual/compute-availability-set.go +++ b/sources/azure/manual/compute-availability-set.go @@ -19,15 +19,14 @@ var ComputeAvailabilitySetLookupByName = shared.NewItemTypeLookup("name", azures type computeAvailabilitySetWrapper struct { client clients.AvailabilitySetsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewComputeAvailabilitySet(client clients.AvailabilitySetsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewComputeAvailabilitySet(client clients.AvailabilitySetsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &computeAvailabilitySetWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION, azureshared.ComputeAvailabilitySet, ), @@ -36,11 +35,11 @@ func NewComputeAvailabilitySet(client clients.AvailabilitySetsClient, subscripti // ref: https://learn.microsoft.com/en-us/rest/api/compute/availability-sets/list?view=rest-compute-2025-04-01&tabs=HTTP func (c computeAvailabilitySetWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - pager := c.client.NewListPager(resourceGroup, nil) + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -64,11 +63,12 @@ func (c computeAvailabilitySetWrapper) List(ctx context.Context, scope string) ( } func (c computeAvailabilitySetWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return } - pager := c.client.NewListPager(resourceGroup, nil) + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) @@ -104,11 +104,11 @@ func (c computeAvailabilitySetWrapper) Get(ctx context.Context, scope string, qu return nil, azureshared.QueryError(errors.New("availabilitySetName cannot be empty"), scope, c.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - availabilitySet, err := c.client.Get(ctx, resourceGroup, availabilitySetName, nil) + availabilitySet, err := c.client.Get(ctx, rgScope.ResourceGroup, availabilitySetName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, c.Type()) } diff --git a/sources/azure/manual/compute-availability-set_test.go b/sources/azure/manual/compute-availability-set_test.go index d934e79b..ad5cd317 100644 --- a/sources/azure/manual/compute-availability-set_test.go +++ 
b/sources/azure/manual/compute-availability-set_test.go @@ -39,7 +39,7 @@ func TestComputeAvailabilitySet(t *testing.T) { AvailabilitySet: *avSet, }, nil) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], availabilitySetName, true) @@ -114,7 +114,7 @@ func TestComputeAvailabilitySet(t *testing.T) { AvailabilitySet: *avSet, }, nil) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], availabilitySetName, true) @@ -160,7 +160,7 @@ func TestComputeAvailabilitySet(t *testing.T) { AvailabilitySet: *avSet, }, nil) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], availabilitySetName, true) @@ -182,7 +182,7 @@ func TestComputeAvailabilitySet(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -219,7 +219,7 @@ func TestComputeAvailabilitySet(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -277,7 +277,7 @@ func TestComputeAvailabilitySet(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -303,7 +303,7 @@ func TestComputeAvailabilitySet(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-avset", nil).Return( armcompute.AvailabilitySetsClientGetResponse{}, expectedErr) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, 
sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-avset", true) @@ -315,7 +315,7 @@ func TestComputeAvailabilitySet(t *testing.T) { t.Run("GetWithEmptyName", func(t *testing.T) { mockClient := mocks.NewMockAvailabilitySetsClient(ctrl) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) @@ -327,7 +327,7 @@ func TestComputeAvailabilitySet(t *testing.T) { t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { mockClient := mocks.NewMockAvailabilitySetsClient(ctrl) - wrapper := manual.NewComputeAvailabilitySet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeAvailabilitySet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test the wrapper's Get method directly with insufficient query parts _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0]) if qErr == nil { diff --git a/sources/azure/manual/compute-disk-encryption-set.go b/sources/azure/manual/compute-disk-encryption-set.go index 88747847..c1d19b8b 100644 --- a/sources/azure/manual/compute-disk-encryption-set.go +++ b/sources/azure/manual/compute-disk-encryption-set.go @@ -19,15 +19,14 @@ var ComputeDiskEncryptionSetLookupByName = shared.NewItemTypeLookup("name", azur type computeDiskEncryptionSetWrapper struct { client clients.DiskEncryptionSetsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewComputeDiskEncryptionSet(client clients.DiskEncryptionSetsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewComputeDiskEncryptionSet(client clients.DiskEncryptionSetsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &computeDiskEncryptionSetWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, azureshared.ComputeDiskEncryptionSet, ), @@ -35,11 +34,11 @@ func NewComputeDiskEncryptionSet(client clients.DiskEncryptionSetsClient, subscr } func (c computeDiskEncryptionSetWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - pager := c.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -62,11 +61,12 @@ func (c computeDiskEncryptionSetWrapper) List(ctx context.Context, scope string) } func (c computeDiskEncryptionSetWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, 
c.Type())) + return } - pager := c.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -92,15 +92,15 @@ func (c computeDiskEncryptionSetWrapper) Get(ctx context.Context, scope string, if len(queryParts) < 1 { return nil, azureshared.QueryError(errors.New("queryParts must be at least 1 and be the disk encryption set name"), scope, c.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } diskEncryptionSetName := queryParts[0] if diskEncryptionSetName == "" { return nil, azureshared.QueryError(errors.New("diskEncryptionSetName cannot be empty"), scope, c.Type()) } - diskEncryptionSet, err := c.client.Get(ctx, resourceGroup, diskEncryptionSetName, nil) + diskEncryptionSet, err := c.client.Get(ctx, rgScope.ResourceGroup, diskEncryptionSetName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, c.Type()) } diff --git a/sources/azure/manual/compute-disk-encryption-set_test.go b/sources/azure/manual/compute-disk-encryption-set_test.go index 957eb5b1..f243fbf1 100644 --- a/sources/azure/manual/compute-disk-encryption-set_test.go +++ b/sources/azure/manual/compute-disk-encryption-set_test.go @@ -40,7 +40,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { nil, ) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], desName, true) @@ -72,7 +72,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { nil, ) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], desName, true) @@ -140,7 +140,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { nil, ) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], desName, true) @@ -241,7 +241,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { nil, ) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], desName, true) @@ -279,7 +279,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockDiskEncryptionSetsClient(ctrl) - wrapper := 
manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) @@ -290,7 +290,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { t.Run("WrapperGet_MissingQueryParts", func(t *testing.T) { mockClient := mocks.NewMockDiskEncryptionSetsClient(ctrl) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) defer func() { if r := recover(); r != nil { @@ -317,7 +317,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { nil, ) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], desName, true) @@ -334,7 +334,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { mockPager := newMockDiskEncryptionSetsPager(ctrl, []*armcompute.DiskEncryptionSet{des1, des2}) mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -362,7 +362,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { mockPager := newMockDiskEncryptionSetsPager(ctrl, []*armcompute.DiskEncryptionSet{des1, desNil}) mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -384,7 +384,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { mockPager := newErrorDiskEncryptionSetsPager(ctrl) mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -409,7 +409,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { mockPager := newMockDiskEncryptionSetsPager(ctrl, []*armcompute.DiskEncryptionSet{des1, des2}) mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -442,7 +442,7 @@ func TestComputeDiskEncryptionSet(t *testing.T) { mockPager := newErrorDiskEncryptionSetsPager(ctrl) mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDiskEncryptionSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDiskEncryptionSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) errCh := make(chan error, 1) diff --git a/sources/azure/manual/compute-disk.go b/sources/azure/manual/compute-disk.go index 78915871..79ba209b 100644 --- a/sources/azure/manual/compute-disk.go +++ b/sources/azure/manual/compute-disk.go @@ -20,15 +20,14 @@ var ComputeDiskLookupByName = shared.NewItemTypeLookup("name", azureshared.Compu type computeDiskWrapper struct { client clients.DisksClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewComputeDisk(client clients.DisksClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewComputeDisk(client clients.DisksClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &computeDiskWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, azureshared.ComputeDisk, ), @@ -36,11 +35,11 @@ func NewComputeDisk(client clients.DisksClient, subscriptionID, resourceGroup st } func (c computeDiskWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - pager := c.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -63,11 +62,12 @@ func (c computeDiskWrapper) List(ctx context.Context, scope string) ([]*sdp.Item } func (c computeDiskWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return } - pager := c.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -94,11 +94,11 @@ func (c computeDiskWrapper) Get(ctx context.Context, scope string, queryParts .. 
return nil, azureshared.QueryError(errors.New("queryParts must be at least 1 and be the disk name"), scope, c.Type()) } diskName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - disk, err := c.client.Get(ctx, resourceGroup, diskName, nil) + disk, err := c.client.Get(ctx, rgScope.ResourceGroup, diskName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, c.Type()) } diff --git a/sources/azure/manual/compute-disk_test.go b/sources/azure/manual/compute-disk_test.go index b2d0de25..e610f146 100644 --- a/sources/azure/manual/compute-disk_test.go +++ b/sources/azure/manual/compute-disk_test.go @@ -39,7 +39,7 @@ func TestComputeDisk(t *testing.T) { Disk: *disk, }, nil) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], diskName, true) @@ -74,7 +74,7 @@ func TestComputeDisk(t *testing.T) { Disk: *disk, }, nil) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], diskName, true) @@ -309,7 +309,7 @@ func TestComputeDisk(t *testing.T) { Disk: *disk, }, nil) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], diskName, true) @@ -347,7 +347,7 @@ func TestComputeDisk(t *testing.T) { Disk: *disk, }, nil) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], diskName, true) @@ -382,7 +382,7 @@ func TestComputeDisk(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -419,7 +419,7 @@ func TestComputeDisk(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := 
&sync.WaitGroup{} @@ -477,7 +477,7 @@ func TestComputeDisk(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -503,7 +503,7 @@ func TestComputeDisk(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-disk", nil).Return( armcompute.DisksClientGetResponse{}, expectedErr) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-disk", true) @@ -518,7 +518,7 @@ func TestComputeDisk(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "", nil).Return( armcompute.DisksClientGetResponse{}, expectedErr) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) @@ -530,7 +530,7 @@ func TestComputeDisk(t *testing.T) { t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { mockClient := mocks.NewMockDisksClient(ctrl) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test the wrapper's Get method directly with insufficient query parts _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0]) if qErr == nil { @@ -544,7 +544,7 @@ func TestComputeDisk(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(errorPager) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -564,7 +564,7 @@ func TestComputeDisk(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(errorPager) - wrapper := manual.NewComputeDisk(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeDisk(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) var errs []error diff --git a/sources/azure/manual/compute-image.go b/sources/azure/manual/compute-image.go index ac40ad5b..86bdf6c2 100644 --- a/sources/azure/manual/compute-image.go +++ b/sources/azure/manual/compute-image.go @@ -20,15 +20,14 @@ var ComputeImageLookupByName = shared.NewItemTypeLookup("name", azureshared.Comp type computeImageWrapper struct { client clients.ImagesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewComputeImage(client clients.ImagesClient, 
subscriptionID, resourceGroup string) sources.ListStreamableWrapper { +func NewComputeImage(client clients.ImagesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListStreamableWrapper { return &computeImageWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION, azureshared.ComputeImage, ), @@ -36,11 +35,11 @@ func NewComputeImage(client clients.ImagesClient, subscriptionID, resourceGroup } func (c computeImageWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - pager := c.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -63,11 +62,12 @@ func (c computeImageWrapper) List(ctx context.Context, scope string) ([]*sdp.Ite } func (c computeImageWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return } - pager := c.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -94,11 +94,11 @@ func (c computeImageWrapper) Get(ctx context.Context, scope string, queryParts . 
 		return nil, azureshared.QueryError(errors.New("queryParts must be exactly 1 and be the image name"), scope, c.Type())
 	}
 	imageName := queryParts[0]
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = c.ResourceGroup()
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
-	image, err := c.client.Get(ctx, resourceGroup, imageName, nil)
+	image, err := c.client.Get(ctx, rgScope.ResourceGroup, imageName, nil)
 	if err != nil {
 		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
diff --git a/sources/azure/manual/compute-image_test.go b/sources/azure/manual/compute-image_test.go
index 48f3f1c3..3ee1cfd2 100644
--- a/sources/azure/manual/compute-image_test.go
+++ b/sources/azure/manual/compute-image_test.go
@@ -40,7 +40,7 @@ func TestComputeImage(t *testing.T) {
 				Image: *image,
 			}, nil)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], imageName, true)
@@ -75,7 +75,7 @@ func TestComputeImage(t *testing.T) {
 				Image: *image,
 			}, nil)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], imageName, true)
@@ -244,7 +244,7 @@ func TestComputeImage(t *testing.T) {
 				Image: *image,
 			}, nil)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], imageName, true)
@@ -279,7 +279,7 @@ func TestComputeImage(t *testing.T) {
 
 		mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		listable, ok := adapter.(discovery.ListableAdapter)
@@ -316,7 +316,7 @@ func TestComputeImage(t *testing.T) {
 
 		mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		wg := &sync.WaitGroup{}
@@ -374,7 +374,7 @@ func TestComputeImage(t *testing.T) {
 
 		mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		listable, ok := adapter.(discovery.ListableAdapter)
@@ -400,7 +400,7 @@ func TestComputeImage(t *testing.T) {
 		mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-image", nil).Return(
 			armcompute.ImagesClientGetResponse{}, expectedErr)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		_, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-image", true)
@@ -412,7 +412,7 @@ func TestComputeImage(t *testing.T) {
 	t.Run("GetWithInsufficientQueryParts", func(t *testing.T) {
 		mockClient := mocks.NewMockImagesClient(ctrl)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		// Test the wrapper's Get method directly with insufficient query parts
 		_, qErr := wrapper.Get(ctx, wrapper.Scopes()[0])
 		if qErr == nil {
@@ -426,7 +426,7 @@ func TestComputeImage(t *testing.T) {
 
 		mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(errorPager)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		listable, ok := adapter.(discovery.ListableAdapter)
@@ -446,7 +446,7 @@ func TestComputeImage(t *testing.T) {
 
 		mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(errorPager)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		var errs []error
@@ -470,7 +470,7 @@ func TestComputeImage(t *testing.T) {
 	t.Run("GetLookups", func(t *testing.T) {
 		mockClient := mocks.NewMockImagesClient(ctrl)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 
 		lookups := wrapper.GetLookups()
 		if len(lookups) != 1 {
@@ -485,7 +485,7 @@ func TestComputeImage(t *testing.T) {
 	t.Run("PotentialLinks", func(t *testing.T) {
 		mockClient := mocks.NewMockImagesClient(ctrl)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 
 		potentialLinks := wrapper.PotentialLinks()
 		expectedLinks := []shared.ItemType{
@@ -507,7 +507,7 @@ func TestComputeImage(t *testing.T) {
 	t.Run("TerraformMappings", func(t *testing.T) {
 		mockClient := mocks.NewMockImagesClient(ctrl)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 
 		mappings := wrapper.TerraformMappings()
 		if len(mappings) != 1 {
@@ -525,7 +525,7 @@ func TestComputeImage(t *testing.T) {
 	t.Run("IAMPermissions", func(t *testing.T) {
 		mockClient := mocks.NewMockImagesClient(ctrl)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 
 		permissions := wrapper.IAMPermissions()
 		expectedPermissions := []string{
@@ -545,7 +545,7 @@ func TestComputeImage(t *testing.T) {
 	t.Run("PredefinedRole", func(t *testing.T) {
 		mockClient := mocks.NewMockImagesClient(ctrl)
 
-		wrapper := manual.NewComputeImage(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeImage(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 
 		// PredefinedRole is available on the wrapper, not the adapter
 		if roleInterface, ok := interface{}(wrapper).(interface{ PredefinedRole() string }); ok {
diff --git a/sources/azure/manual/compute-proximity-placement-group.go b/sources/azure/manual/compute-proximity-placement-group.go
index 8ab9d101..42029628 100644
--- a/sources/azure/manual/compute-proximity-placement-group.go
+++ b/sources/azure/manual/compute-proximity-placement-group.go
@@ -10,21 +10,22 @@ import (
 	"github.com/overmindtech/cli/sources/azure/clients"
 	azureshared "github.com/overmindtech/cli/sources/azure/shared"
 	"github.com/overmindtech/cli/sources/shared"
+	"github.com/overmindtech/cli/sdpcache"
+	"github.com/overmindtech/cli/discovery"
 )
 
 var ComputeProximityPlacementGroupLookupByName = shared.NewItemTypeLookup("name", azureshared.ComputeProximityPlacementGroup)
 
 type computeProximityPlacementGroupWrapper struct {
 	client clients.ProximityPlacementGroupsClient
-	*azureshared.ResourceGroupBase
+	*azureshared.MultiResourceGroupBase
 }
 
-func NewComputeProximityPlacementGroup(client clients.ProximityPlacementGroupsClient, subscriptionID, resourceGroup string) sources.ListableWrapper {
+func NewComputeProximityPlacementGroup(client clients.ProximityPlacementGroupsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper {
 	return &computeProximityPlacementGroupWrapper{
 		client: client,
-		ResourceGroupBase: azureshared.NewResourceGroupBase(
-			subscriptionID,
-			resourceGroup,
+		MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase(
+			resourceGroupScopes,
 			sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION,
 			azureshared.ComputeProximityPlacementGroup,
 		),
@@ -33,11 +34,11 @@ func NewComputeProximityPlacementGroup(client clients.ProximityPlacementGroupsCl
 
 // ref: https://learn.microsoft.com/en-us/rest/api/compute/proximity-placement-groups/list-by-resource-group?view=rest-compute-2025-04-01&tabs=HTTP
 func (c computeProximityPlacementGroupWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) {
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = c.ResourceGroup()
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
-	pager := c.client.ListByResourceGroup(ctx, resourceGroup, nil)
+	pager := c.client.ListByResourceGroup(ctx, rgScope.ResourceGroup, nil)
 	var items []*sdp.Item
 	for pager.More() {
 		page, err := pager.NextPage(ctx)
@@ -58,17 +59,45 @@ func (c computeProximityPlacementGroupWrapper) List(ctx context.Context, scope s
 	return items, nil
 }
 
+func (c computeProximityPlacementGroupWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) {
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		stream.SendError(azureshared.QueryError(err, scope, c.Type()))
+		return
+	}
+	pager := c.client.ListByResourceGroup(ctx, rgScope.ResourceGroup, nil)
+	for pager.More() {
+		page, err := pager.NextPage(ctx)
+		if err != nil {
+			stream.SendError(azureshared.QueryError(err, scope, c.Type()))
+			return
+		}
+		for _, proximityPlacementGroup := range page.Value {
+			if proximityPlacementGroup.Name == nil {
+				continue
+			}
+			item, sdpErr := c.azureProximityPlacementGroupToSDPItem(proximityPlacementGroup, scope)
+			if sdpErr != nil {
+				stream.SendError(sdpErr)
+				continue
+			}
+			cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey)
+			stream.SendItem(item)
+		}
+	}
+}
+
 // ref: https://learn.microsoft.com/en-us/rest/api/compute/proximity-placement-groups/get?view=rest-compute-2025-04-01&tabs=HTTP
 func (c computeProximityPlacementGroupWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) {
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = c.ResourceGroup()
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
 	if len(queryParts) < 1 {
 		return nil, azureshared.QueryError(errors.New("queryParts must be at least 1 and be the proximity placement group name"), scope, c.Type())
 	}
 	proximityPlacementGroupName := queryParts[0]
-	resp, err := c.client.Get(ctx, resourceGroup, proximityPlacementGroupName, nil)
+	resp, err := c.client.Get(ctx, rgScope.ResourceGroup, proximityPlacementGroupName, nil)
 	if err != nil {
 		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
diff --git a/sources/azure/manual/compute-proximity-placement-group_test.go b/sources/azure/manual/compute-proximity-placement-group_test.go
index 3aef3639..23921bae 100644
--- a/sources/azure/manual/compute-proximity-placement-group_test.go
+++ b/sources/azure/manual/compute-proximity-placement-group_test.go
@@ -39,7 +39,7 @@ func TestComputeProximityPlacementGroup(t *testing.T) {
 				ProximityPlacementGroup: *ppg,
 			}, nil)
 
-		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		sdpItem, qErr := adapter.Get(ctx, scope, ppgName, true)
@@ -111,7 +111,7 @@ func TestComputeProximityPlacementGroup(t *testing.T) {
 				ProximityPlacementGroup: *ppg,
 			}, nil)
 
-		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		sdpItem, qErr := adapter.Get(ctx, scope, ppgName, true)
@@ -152,7 +152,7 @@ func TestComputeProximityPlacementGroup(t *testing.T) {
 				ProximityPlacementGroup: *ppg,
 			}, nil)
 
-		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		sdpItem, qErr := adapter.Get(ctx, scope, ppgName, true)
@@ -174,7 +174,7 @@ func TestComputeProximityPlacementGroup(t *testing.T) {
 
 		mockClient.EXPECT().ListByResourceGroup(ctx, resourceGroup, nil).Return(mockPager)
 
-		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		listable, ok := adapter.(discovery.ListableAdapter)
@@ -220,7 +220,7 @@ func TestComputeProximityPlacementGroup(t *testing.T) {
 
 		mockClient.EXPECT().ListByResourceGroup(ctx, resourceGroup, nil).Return(mockPager)
 
-		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		listable, ok := adapter.(discovery.ListableAdapter)
@@ -245,7 +245,7 @@ func TestComputeProximityPlacementGroup(t *testing.T) {
 		mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-ppg", nil).Return(
 			armcompute.ProximityPlacementGroupsClientGetResponse{}, expectedErr)
 
-		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		_, qErr := adapter.Get(ctx, scope, "nonexistent-ppg", true)
@@ -259,7 +259,7 @@ func TestComputeProximityPlacementGroup(t *testing.T) {
 		mockClient.EXPECT().Get(ctx, resourceGroup, "", nil).Return(
 			armcompute.ProximityPlacementGroupsClientGetResponse{}, errors.New("proximity placement group name is required"))
 
-		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeProximityPlacementGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		_, qErr := adapter.Get(ctx, scope, "", true)
diff --git a/sources/azure/manual/compute-virtual-machine-extension.go b/sources/azure/manual/compute-virtual-machine-extension.go
index 82c81ed5..5017576e 100644
--- a/sources/azure/manual/compute-virtual-machine-extension.go
+++ b/sources/azure/manual/compute-virtual-machine-extension.go
@@ -18,15 +18,14 @@ var ComputeVirtualMachineExtensionLookupByName = shared.NewItemTypeLookup("name"
 type computeVirtualMachineExtensionWrapper struct {
 	client clients.VirtualMachineExtensionsClient
-	*azureshared.ResourceGroupBase
+	*azureshared.MultiResourceGroupBase
 }
 
-func NewComputeVirtualMachineExtension(client clients.VirtualMachineExtensionsClient, subscriptionID, resourceGroup string) sources.SearchableWrapper {
+func NewComputeVirtualMachineExtension(client clients.VirtualMachineExtensionsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper {
 	return &computeVirtualMachineExtensionWrapper{
 		client: client,
-		ResourceGroupBase: azureshared.NewResourceGroupBase(
-			subscriptionID,
-			resourceGroup,
+		MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase(
+			resourceGroupScopes,
 			sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION,
azureshared.ComputeVirtualMachineExtension, ), @@ -46,11 +45,11 @@ func (c computeVirtualMachineExtensionWrapper) Get(ctx context.Context, scope st return nil, azureshared.QueryError(fmt.Errorf("extensionName cannot be empty"), scope, c.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - resp, err := c.client.Get(ctx, resourceGroup, virtualMachineName, extensionName, nil) + resp, err := c.client.Get(ctx, rgScope.ResourceGroup, virtualMachineName, extensionName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, c.Type()) } @@ -226,12 +225,12 @@ func (c computeVirtualMachineExtensionWrapper) Search(ctx context.Context, scope return nil, azureshared.QueryError(fmt.Errorf("virtualMachineName cannot be empty"), scope, c.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - resp, err := c.client.List(ctx, resourceGroup, virtualMachineName, nil) + resp, err := c.client.List(ctx, rgScope.ResourceGroup, virtualMachineName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, c.Type()) } diff --git a/sources/azure/manual/compute-virtual-machine-extension_test.go b/sources/azure/manual/compute-virtual-machine-extension_test.go index 92a57e84..c2cd9a87 100644 --- a/sources/azure/manual/compute-virtual-machine-extension_test.go +++ b/sources/azure/manual/compute-virtual-machine-extension_test.go @@ -39,7 +39,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { VirtualMachineExtension: *extension, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -90,7 +90,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { VirtualMachineExtension: *extension, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -156,7 +156,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { VirtualMachineExtension: *extension, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." 
+ resourceGroup @@ -212,7 +212,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { VirtualMachineExtension: *extension, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -271,7 +271,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { VirtualMachineExtension: *extension, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -313,7 +313,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { VirtualMachineExtension: *extension, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -360,7 +360,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { VirtualMachineExtension: *extension, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -398,7 +398,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -418,7 +418,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("EmptyVirtualMachineName", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." 
+ resourceGroup @@ -430,7 +430,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("EmptyExtensionName", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -446,7 +446,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { armcompute.VirtualMachineExtensionsClientGetResponse{}, errors.New("client error")) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -473,7 +473,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { }, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -502,7 +502,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -532,7 +532,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { armcompute.VirtualMachineExtensionsClientListResponse{}, errors.New("client error")) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -561,7 +561,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { }, }, nil) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -584,7 +584,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) links := wrapper.PotentialLinks() @@ -605,7 +605,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("GetLookups", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) lookups := wrapper.GetLookups() if len(lookups) != 2 { @@ -625,7 +625,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("SearchLookups", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) searchLookups := wrapper.SearchLookups() if len(searchLookups) != 1 { @@ -644,7 +644,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("TerraformMappings", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) mappings := wrapper.TerraformMappings() if len(mappings) != 1 { @@ -662,7 +662,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("IAMPermissions", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) permissions := wrapper.IAMPermissions() if len(permissions) != 1 { @@ -677,7 +677,7 @@ func TestComputeVirtualMachineExtension(t *testing.T) { t.Run("PredefinedRole", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineExtensionsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineExtension(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineExtension(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // PredefinedRole is available on the wrapper, not the adapter role := wrapper.(interface{ PredefinedRole() string }).PredefinedRole() diff --git a/sources/azure/manual/compute-virtual-machine-run-command.go b/sources/azure/manual/compute-virtual-machine-run-command.go index 0071af35..62bede15 100644 --- a/sources/azure/manual/compute-virtual-machine-run-command.go +++ b/sources/azure/manual/compute-virtual-machine-run-command.go @@ -19,15 +19,14 @@ var ComputeVirtualMachineRunCommandLookupByName = shared.NewItemTypeLookup("name type computeVirtualMachineRunCommandWrapper struct { client clients.VirtualMachineRunCommandsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewComputeVirtualMachineRunCommand(client clients.VirtualMachineRunCommandsClient, subscriptionID, resourceGroup string) sources.SearchableWrapper { +func NewComputeVirtualMachineRunCommand(client 
clients.VirtualMachineRunCommandsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper {
 	return &computeVirtualMachineRunCommandWrapper{
 		client: client,
-		ResourceGroupBase: azureshared.NewResourceGroupBase(
-			subscriptionID,
-			resourceGroup,
+		MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase(
+			resourceGroupScopes,
 			sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION,
 			azureshared.ComputeVirtualMachineRunCommand,
 		),
@@ -56,12 +55,12 @@ func (s computeVirtualMachineRunCommandWrapper) Get(ctx context.Context, scope s
 		return nil, azureshared.QueryError(errors.New("runCommandName cannot be empty"), scope, s.Type())
 	}
 
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = s.ResourceGroup()
+	rgScope, err := s.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, s.Type())
 	}
 
-	resp, err := s.client.GetByVirtualMachine(ctx, resourceGroup, virtualMachineName, runCommandName, nil)
+	resp, err := s.client.GetByVirtualMachine(ctx, rgScope.ResourceGroup, virtualMachineName, runCommandName, nil)
 	if err != nil {
 		return nil, azureshared.QueryError(err, scope, s.Type())
 	}
@@ -285,11 +284,11 @@ func (s computeVirtualMachineRunCommandWrapper) Search(ctx context.Context, scop
 		return nil, azureshared.QueryError(errors.New("virtualMachineName cannot be empty"), scope, s.Type())
 	}
 
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = s.ResourceGroup()
+	rgScope, err := s.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, s.Type())
 	}
 
-	pager := s.client.NewListByVirtualMachinePager(resourceGroup, virtualMachineName, nil)
+	pager := s.client.NewListByVirtualMachinePager(rgScope.ResourceGroup, virtualMachineName, nil)
 
 	var items []*sdp.Item
 	for pager.More() {
diff --git a/sources/azure/manual/compute-virtual-machine-run-command_test.go b/sources/azure/manual/compute-virtual-machine-run-command_test.go
index a7ae6585..e33ef6ac 100644
--- a/sources/azure/manual/compute-virtual-machine-run-command_test.go
+++ b/sources/azure/manual/compute-virtual-machine-run-command_test.go
@@ -119,7 +119,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) {
 			VirtualMachineRunCommand: *runCommand,
 		}, nil)
 
-		wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		scope := subscriptionID + "." + resourceGroup
@@ -170,7 +170,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) {
 			VirtualMachineRunCommand: *runCommand,
 		}, nil)
 
-		wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		scope := subscriptionID + "."
+ resourceGroup @@ -267,7 +267,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { VirtualMachineRunCommand: *runCommand, }, nil) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -317,7 +317,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { VirtualMachineRunCommand: *runCommand, }, nil) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -342,7 +342,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("Get_ErrorHandling", func(t *testing.T) { t.Run("EmptyScope", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, "", shared.CompositeLookupKey(vmName, runCommandName), true) @@ -353,7 +353,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("WrongQueryPartsCount", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -365,7 +365,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("EmptyVirtualMachineName", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -377,7 +377,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("EmptyRunCommandName", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." 
+ resourceGroup @@ -393,7 +393,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { armcompute.VirtualMachineRunCommandsClientGetByVirtualMachineResponse{}, errors.New("client error")) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) scope := subscriptionID + "." + resourceGroup @@ -425,7 +425,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { pager: mockPager, } - wrapper := manual.NewComputeVirtualMachineRunCommand(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -454,7 +454,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("Search_ErrorHandling", func(t *testing.T) { t.Run("WrongQueryPartsCount", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -471,7 +471,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("EmptyVirtualMachineName", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -495,7 +495,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { pager: errorPager, } - wrapper := manual.NewComputeVirtualMachineRunCommand(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -536,7 +536,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { pager: mockPager, } - wrapper := manual.NewComputeVirtualMachineRunCommand(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -559,7 +559,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + 
wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) potentialLinks := wrapper.PotentialLinks() expectedLinks := map[shared.ItemType]bool{ @@ -587,7 +587,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("IAMPermissions", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) permissions := wrapper.IAMPermissions() expectedPermission := "Microsoft.Compute/virtualMachines/runCommands/read" @@ -603,7 +603,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("TerraformMappings", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) mappings := wrapper.TerraformMappings() if len(mappings) != 1 { @@ -621,7 +621,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("GetLookups", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) lookups := wrapper.GetLookups() if len(lookups) != 2 { @@ -631,7 +631,7 @@ func TestComputeVirtualMachineRunCommand(t *testing.T) { t.Run("SearchLookups", func(t *testing.T) { mockClient := mocks.NewMockVirtualMachineRunCommandsClient(ctrl) - wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineRunCommand(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) searchLookups := wrapper.SearchLookups() if len(searchLookups) != 1 { diff --git a/sources/azure/manual/compute-virtual-machine-scale-set.go b/sources/azure/manual/compute-virtual-machine-scale-set.go index 9d51aa47..44b2354e 100644 --- a/sources/azure/manual/compute-virtual-machine-scale-set.go +++ b/sources/azure/manual/compute-virtual-machine-scale-set.go @@ -21,15 +21,14 @@ var ComputeVirtualMachineScaleSetLookupByName = shared.NewItemTypeLookup("name", type computeVirtualMachineScaleSetWrapper struct { client clients.VirtualMachineScaleSetsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewComputeVirtualMachineScaleSet(client clients.VirtualMachineScaleSetsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewComputeVirtualMachineScaleSet(client clients.VirtualMachineScaleSetsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &computeVirtualMachineScaleSetWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION, 
 			azureshared.ComputeVirtualMachineScaleSet,
 		),
@@ -38,11 +37,11 @@ func NewComputeVirtualMachineScaleSet(client clients.VirtualMachineScaleSetsClie
 // ref: https://linear.app/overmind/issue/ENG-2114/create-microsoftcomputevirtualmachinescalesets-adapter
 func (c computeVirtualMachineScaleSetWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) {
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = c.ResourceGroup()
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
-	pager := c.client.NewListPager(resourceGroup, nil)
+	pager := c.client.NewListPager(rgScope.ResourceGroup, nil)
 
 	var items []*sdp.Item
 	for pager.More() {
@@ -63,11 +62,12 @@ func (c computeVirtualMachineScaleSetWrapper) List(ctx context.Context, scope st
 }
 
 func (c computeVirtualMachineScaleSetWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) {
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = c.ResourceGroup()
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		stream.SendError(azureshared.QueryError(err, scope, c.Type()))
+		return
 	}
-	pager := c.client.NewListPager(resourceGroup, nil)
+	pager := c.client.NewListPager(rgScope.ResourceGroup, nil)
 	for pager.More() {
 		page, err := pager.NextPage(ctx)
 		if err != nil {
@@ -95,11 +95,11 @@ func (c computeVirtualMachineScaleSetWrapper) Get(ctx context.Context, scope str
 	if scaleSetName == "" {
 		return nil, azureshared.QueryError(errors.New("scaleSetName cannot be empty"), scope, c.Type())
 	}
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = c.ResourceGroup()
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
-	scaleSet, err := c.client.Get(ctx, resourceGroup, scaleSetName, nil)
+	scaleSet, err := c.client.Get(ctx, rgScope.ResourceGroup, scaleSetName, nil)
 	if err != nil {
 		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
diff --git a/sources/azure/manual/compute-virtual-machine-scale-set_test.go b/sources/azure/manual/compute-virtual-machine-scale-set_test.go
index 3c037329..fbdc03e9 100644
--- a/sources/azure/manual/compute-virtual-machine-scale-set_test.go
+++ b/sources/azure/manual/compute-virtual-machine-scale-set_test.go
@@ -39,7 +39,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) {
 			VirtualMachineScaleSet: *scaleSet,
 		}, nil)
 
-		wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], scaleSetName, true)
@@ -331,7 +331,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) {
 	t.Run("Get_InvalidQueryParts", func(t *testing.T) {
 		mockClient := mocks.NewMockVirtualMachineScaleSetsClient(ctrl)
 
-		wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter :=
sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty string name @@ -351,7 +351,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) { VirtualMachineScaleSet: *scaleSet, }, nil) - wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-vmss", true) @@ -416,7 +416,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) { VirtualMachineScaleSet: *scaleSet, }, nil) - wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-vmss", true) @@ -452,7 +452,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -501,7 +501,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -551,7 +551,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-vmss", nil).Return( armcompute.VirtualMachineScaleSetsClientGetResponse{}, expectedErr) - wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-vmss", true) @@ -575,7 +575,7 @@ func TestComputeVirtualMachineScaleSet(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachineScaleSet(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) diff --git a/sources/azure/manual/compute-virtual-machine.go b/sources/azure/manual/compute-virtual-machine.go index 2daf42e5..5e5f78e9 100644 --- a/sources/azure/manual/compute-virtual-machine.go +++ 
b/sources/azure/manual/compute-virtual-machine.go
@@ -21,16 +21,15 @@ var ComputeVirtualMachineLookupByName = shared.NewItemTypeLookup("name", azuresh
 type computeVirtualMachineWrapper struct {
 	client clients.VirtualMachinesClient
 
-	*azureshared.ResourceGroupBase
+	*azureshared.MultiResourceGroupBase
 }
 
 // NewComputeVirtualMachine creates a new computeVirtualMachineWrapper instance
-func NewComputeVirtualMachine(client clients.VirtualMachinesClient, subscriptionID, resourceGroup string) sources.ListableWrapper {
+func NewComputeVirtualMachine(client clients.VirtualMachinesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper {
 	return &computeVirtualMachineWrapper{
 		client: client,
-		ResourceGroupBase: azureshared.NewResourceGroupBase(
-			subscriptionID,
-			resourceGroup,
+		MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase(
+			resourceGroupScopes,
 			sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION,
 			azureshared.ComputeVirtualMachine,
 		),
@@ -105,11 +104,11 @@ func (c computeVirtualMachineWrapper) GetLookups() sources.ItemTypeLookups {
 func (c computeVirtualMachineWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) {
 	vmName := queryParts[0]
 
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = c.ResourceGroup()
+	rgScope, err := c.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
-	resp, err := c.client.Get(ctx, resourceGroup, vmName, nil)
+	resp, err := c.client.Get(ctx, rgScope.ResourceGroup, vmName, nil)
 	if err != nil {
 		return nil, azureshared.QueryError(err, scope, c.Type())
 	}
@@ -127,11 +126,11 @@ func (c computeVirtualMachineWrapper) Get(ctx context.Context, scope string, que
 // List lists virtual machines in the resource group and converts them to sdp.Items.
// Reference: https://learn.microsoft.com/en-us/rest/api/compute/virtual-machines/list func (c computeVirtualMachineWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) } - pager := c.client.NewListPager(resourceGroup, nil) + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -156,11 +155,12 @@ func (c computeVirtualMachineWrapper) List(ctx context.Context, scope string) ([ } func (c computeVirtualMachineWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = c.ResourceGroup() + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return } - pager := c.client.NewListPager(resourceGroup, nil) + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) @@ -210,17 +210,17 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties.StorageProfile.OSDisk.ManagedDisk != nil && vm.Properties.StorageProfile.OSDisk.ManagedDisk.ID != nil { diskName := azureshared.ExtractResourceName(*vm.Properties.StorageProfile.OSDisk.ManagedDisk.ID) if diskName != "" { - scope := c.DefaultScope() + linkScope := scope // Check if disk is in a different resource group if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.StorageProfile.OSDisk.ManagedDisk.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeDisk.String(), Method: sdp.QueryMethod_GET, Query: diskName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If disk changes → VM affected (In: true) @@ -233,16 +233,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet != nil && vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID != nil { diskEncryptionSetName := azureshared.ExtractResourceName(*vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID) if diskEncryptionSetName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.StorageProfile.OSDisk.ManagedDisk.DiskEncryptionSet.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeDiskEncryptionSet.String(), Method: sdp.QueryMethod_GET, Query: diskEncryptionSetName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If encryption set changes → disk encryption affected (In: true) @@ -260,17 +260,17 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if dataDisk.ManagedDisk != nil && dataDisk.ManagedDisk.ID != nil { diskName := 
azureshared.ExtractResourceName(*dataDisk.ManagedDisk.ID) if diskName != "" { - scope := c.DefaultScope() + linkScope := scope // Check if disk is in a different resource group if extractedScope := azureshared.ExtractScopeFromResourceID(*dataDisk.ManagedDisk.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeDisk.String(), Method: sdp.QueryMethod_GET, Query: diskName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If disk changes → VM affected (In: true) @@ -283,16 +283,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if dataDisk.ManagedDisk.DiskEncryptionSet != nil && dataDisk.ManagedDisk.DiskEncryptionSet.ID != nil { diskEncryptionSetName := azureshared.ExtractResourceName(*dataDisk.ManagedDisk.DiskEncryptionSet.ID) if diskEncryptionSetName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*dataDisk.ManagedDisk.DiskEncryptionSet.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeDiskEncryptionSet.String(), Method: sdp.QueryMethod_GET, Query: diskEncryptionSetName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If encryption set changes → disk encryption affected (In: true) @@ -312,17 +312,17 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if nic.ID != nil { nicName := azureshared.ExtractResourceName(*nic.ID) if nicName != "" { - scope := c.DefaultScope() + linkScope := scope // Check if NIC is in a different resource group if extractedScope := azureshared.ExtractScopeFromResourceID(*nic.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.NetworkNetworkInterface.String(), Method: sdp.QueryMethod_GET, Query: nicName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If NIC changes → VM network connectivity affected (In: true) @@ -339,17 +339,17 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties != nil && vm.Properties.AvailabilitySet != nil && vm.Properties.AvailabilitySet.ID != nil { availabilitySetName := azureshared.ExtractResourceName(*vm.Properties.AvailabilitySet.ID) if availabilitySetName != "" { - scope := c.DefaultScope() + linkScope := scope // Check if availability set is in a different resource group if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.AvailabilitySet.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeAvailabilitySet.String(), Method: sdp.QueryMethod_GET, Query: availabilitySetName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If availability set changes → VM placement affected (In: true) @@ -364,16 +364,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties != nil && 
vm.Properties.ProximityPlacementGroup != nil && vm.Properties.ProximityPlacementGroup.ID != nil { proximityPlacementGroupName := azureshared.ExtractResourceName(*vm.Properties.ProximityPlacementGroup.ID) if proximityPlacementGroupName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.ProximityPlacementGroup.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeProximityPlacementGroup.String(), Method: sdp.QueryMethod_GET, Query: proximityPlacementGroupName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If proximity placement group changes → VM placement affected (In: true) @@ -388,16 +388,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties != nil && vm.Properties.HostGroup != nil && vm.Properties.HostGroup.ID != nil { hostGroupName := azureshared.ExtractResourceName(*vm.Properties.HostGroup.ID) if hostGroupName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.HostGroup.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeDedicatedHostGroup.String(), Method: sdp.QueryMethod_GET, Query: hostGroupName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If host group changes → VM host placement affected (In: true) @@ -412,16 +412,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties != nil && vm.Properties.CapacityReservation != nil && vm.Properties.CapacityReservation.CapacityReservationGroup != nil && vm.Properties.CapacityReservation.CapacityReservationGroup.ID != nil { capacityReservationGroupName := azureshared.ExtractResourceName(*vm.Properties.CapacityReservation.CapacityReservationGroup.ID) if capacityReservationGroupName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.CapacityReservation.CapacityReservationGroup.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeCapacityReservationGroup.String(), Method: sdp.QueryMethod_GET, Query: capacityReservationGroupName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If capacity reservation group changes → VM capacity reservation affected (In: true) @@ -436,16 +436,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties != nil && vm.Properties.VirtualMachineScaleSet != nil && vm.Properties.VirtualMachineScaleSet.ID != nil { vmssName := azureshared.ExtractResourceName(*vm.Properties.VirtualMachineScaleSet.ID) if vmssName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.VirtualMachineScaleSet.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: 
&sdp.Query{ Type: azureshared.ComputeVirtualMachineScaleSet.String(), Method: sdp.QueryMethod_GET, Query: vmssName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If VMSS changes → VM configuration affected (In: true) @@ -462,16 +462,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if strings.Contains(*vm.ManagedBy, "/virtualMachineScaleSets/") { vmssName := azureshared.ExtractPathParamsFromResourceID(*vm.ManagedBy, []string{"virtualMachineScaleSets"}) if len(vmssName) > 0 && vmssName[0] != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.ManagedBy); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeVirtualMachineScaleSet.String(), Method: sdp.QueryMethod_GET, Query: vmssName[0], - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If VMSS changes → VM configuration affected (In: true) @@ -496,16 +496,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput galleryName := params[0] imageName := params[1] versionName := params[2] - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(imageID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeSharedGalleryImage.String(), Method: sdp.QueryMethod_GET, Query: shared.CompositeLookupKey(galleryName, imageName, versionName), - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If image version changes → VM image affected (In: true) @@ -517,16 +517,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput // Custom Image imageName := azureshared.ExtractResourceName(imageID) if imageName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(imageID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ComputeImage.String(), Method: sdp.QueryMethod_GET, Query: imageName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If image changes → VM image affected (In: true) @@ -544,16 +544,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput for identityID := range vm.Identity.UserAssignedIdentities { identityName := azureshared.ExtractResourceName(identityID) if identityName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(identityID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.ManagedIdentityUserAssignedIdentity.String(), Method: sdp.QueryMethod_GET, Query: identityName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If identity changes → VM identity access affected (In: true) @@ -571,16 +571,16 @@ func (c computeVirtualMachineWrapper) 
azureVirtualMachineToSDPItem(vm *armcomput if secret.SourceVault != nil && secret.SourceVault.ID != nil { vaultName := azureshared.ExtractResourceName(*secret.SourceVault.ID) if vaultName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*secret.SourceVault.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.KeyVaultVault.String(), Method: sdp.QueryMethod_GET, Query: vaultName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If Key Vault changes → VM secrets access affected (In: true) @@ -601,16 +601,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties.StorageProfile.OSDisk.EncryptionSettings.DiskEncryptionKey != nil && vm.Properties.StorageProfile.OSDisk.EncryptionSettings.DiskEncryptionKey.SourceVault != nil && vm.Properties.StorageProfile.OSDisk.EncryptionSettings.DiskEncryptionKey.SourceVault.ID != nil { vaultName := azureshared.ExtractResourceName(*vm.Properties.StorageProfile.OSDisk.EncryptionSettings.DiskEncryptionKey.SourceVault.ID) if vaultName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.StorageProfile.OSDisk.EncryptionSettings.DiskEncryptionKey.SourceVault.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.KeyVaultVault.String(), Method: sdp.QueryMethod_GET, Query: vaultName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If Key Vault changes → disk encryption affected (In: true) @@ -623,16 +623,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput if vm.Properties.StorageProfile.OSDisk.EncryptionSettings.KeyEncryptionKey != nil && vm.Properties.StorageProfile.OSDisk.EncryptionSettings.KeyEncryptionKey.SourceVault != nil && vm.Properties.StorageProfile.OSDisk.EncryptionSettings.KeyEncryptionKey.SourceVault.ID != nil { vaultName := azureshared.ExtractResourceName(*vm.Properties.StorageProfile.OSDisk.EncryptionSettings.KeyEncryptionKey.SourceVault.ID) if vaultName != "" { - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(*vm.Properties.StorageProfile.OSDisk.EncryptionSettings.KeyEncryptionKey.SourceVault.ID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ Type: azureshared.KeyVaultVault.String(), Method: sdp.QueryMethod_GET, Query: vaultName, - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If Key Vault changes → disk encryption affected (In: true) @@ -656,16 +656,16 @@ func (c computeVirtualMachineWrapper) azureVirtualMachineToSDPItem(vm *armcomput galleryName := params[0] appName := params[1] versionName := params[2] - scope := c.DefaultScope() + linkScope := scope if extractedScope := azureshared.ExtractScopeFromResourceID(packageRefID); extractedScope != "" { - scope = extractedScope + linkScope = extractedScope } sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ Query: &sdp.Query{ 
Type: azureshared.ComputeSharedGalleryApplicationVersion.String(), Method: sdp.QueryMethod_GET, Query: shared.CompositeLookupKey(galleryName, appName, versionName), - Scope: scope, + Scope: linkScope, }, BlastPropagation: &sdp.BlastPropagation{ In: true, // If application version changes → VM application affected (In: true) diff --git a/sources/azure/manual/compute-virtual-machine_test.go b/sources/azure/manual/compute-virtual-machine_test.go index ef7d52aa..857b9b64 100644 --- a/sources/azure/manual/compute-virtual-machine_test.go +++ b/sources/azure/manual/compute-virtual-machine_test.go @@ -38,7 +38,7 @@ func TestComputeVirtualMachine(t *testing.T) { VirtualMachine: *vm, }, nil) - wrapper := manual.NewComputeVirtualMachine(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachine(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vmName, true) @@ -196,7 +196,7 @@ func TestComputeVirtualMachine(t *testing.T) { VirtualMachine: *vm, }, nil) - wrapper := manual.NewComputeVirtualMachine(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachine(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-vm", true) @@ -232,7 +232,7 @@ func TestComputeVirtualMachine(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeVirtualMachine(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachine(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -281,7 +281,7 @@ func TestComputeVirtualMachine(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewComputeVirtualMachine(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachine(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -331,7 +331,7 @@ func TestComputeVirtualMachine(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-vm", nil).Return( armcompute.VirtualMachinesClientGetResponse{}, expectedErr) - wrapper := manual.NewComputeVirtualMachine(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewComputeVirtualMachine(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-vm", true) diff --git a/sources/azure/manual/dbforpostgresql-database.go b/sources/azure/manual/dbforpostgresql-database.go index 927caaaa..0d6f3a8a 100644 --- a/sources/azure/manual/dbforpostgresql-database.go +++ b/sources/azure/manual/dbforpostgresql-database.go @@ -17,15 +17,14 @@ var DBforPostgreSQLDatabaseLookupByName = shared.NewItemTypeLookup("name", azure type dbforPostgreSQLDatabaseWrapper struct { client 
clients.PostgreSQLDatabasesClient
 
-	*azureshared.ResourceGroupBase
+	*azureshared.MultiResourceGroupBase
 }
 
-func NewDBforPostgreSQLDatabase(client clients.PostgreSQLDatabasesClient, subscriptionID, resourceGroup string) sources.SearchableWrapper {
+func NewDBforPostgreSQLDatabase(client clients.PostgreSQLDatabasesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper {
 	return &dbforPostgreSQLDatabaseWrapper{
 		client: client,
-		ResourceGroupBase: azureshared.NewResourceGroupBase(
-			subscriptionID,
-			resourceGroup,
+		MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase(
+			resourceGroupScopes,
 			sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE,
 			azureshared.DBforPostgreSQLDatabase,
 		),
@@ -46,11 +45,11 @@ func (s dbforPostgreSQLDatabaseWrapper) Get(ctx context.Context, scope string, q
 	serverName := queryParts[0]
 	databaseName := queryParts[1]
 
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = s.ResourceGroup()
+	rgScope, err := s.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, s.Type())
 	}
-	resp, err := s.client.Get(ctx, resourceGroup, serverName, databaseName)
+	resp, err := s.client.Get(ctx, rgScope.ResourceGroup, serverName, databaseName)
 	if err != nil {
 		return nil, azureshared.QueryError(err, scope, s.Type())
 	}
@@ -114,11 +113,11 @@ func (s dbforPostgreSQLDatabaseWrapper) Search(ctx context.Context, scope string
 	}
 	serverName := queryParts[0]
 
-	resourceGroup := azureshared.ResourceGroupFromScope(scope)
-	if resourceGroup == "" {
-		resourceGroup = s.ResourceGroup()
+	rgScope, err := s.ResourceGroupScopeFromScope(scope)
+	if err != nil {
+		return nil, azureshared.QueryError(err, scope, s.Type())
 	}
-	pager := s.client.ListByServer(ctx, resourceGroup, serverName)
+	pager := s.client.ListByServer(ctx, rgScope.ResourceGroup, serverName)
 
 	var items []*sdp.Item
 	for pager.More() {
diff --git a/sources/azure/manual/dbforpostgresql-database_test.go b/sources/azure/manual/dbforpostgresql-database_test.go
index aad0ead4..43d71ff5 100644
--- a/sources/azure/manual/dbforpostgresql-database_test.go
+++ b/sources/azure/manual/dbforpostgresql-database_test.go
@@ -80,7 +80,7 @@ func TestDBforPostgreSQLDatabase(t *testing.T) {
 		}, nil)
 
 		testClient := &testPostgreSQLDatabasesClient{MockPostgreSQLDatabasesClient: mockClient}
-		wrapper := manual.NewDBforPostgreSQLDatabase(testClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewDBforPostgreSQLDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		// Get requires serverName and databaseName as query parts
@@ -135,7 +135,7 @@ func TestDBforPostgreSQLDatabase(t *testing.T) {
 		mockClient := mocks.NewMockPostgreSQLDatabasesClient(ctrl)
 
 		testClient := &testPostgreSQLDatabasesClient{MockPostgreSQLDatabasesClient: mockClient}
-		wrapper := manual.NewDBforPostgreSQLDatabase(testClient, subscriptionID, resourceGroup)
+		wrapper := manual.NewDBforPostgreSQLDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)})
 		adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache())
 
 		// Test with insufficient query parts (only server name)
@@ -165,7 +165,7 @@ func TestDBforPostgreSQLDatabase(t *testing.T) {
 			pager: mockPager,
 		}
 
-		wrapper := manual.NewDBforPostgreSQLDatabase(testClient, subscriptionID, resourceGroup)
+		wrapper :=
manual.NewDBforPostgreSQLDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -219,7 +219,7 @@ func TestDBforPostgreSQLDatabase(t *testing.T) { pager: mockPager, } - wrapper := manual.NewDBforPostgreSQLDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -246,7 +246,7 @@ func TestDBforPostgreSQLDatabase(t *testing.T) { mockClient := mocks.NewMockPostgreSQLDatabasesClient(ctrl) testClient := &testPostgreSQLDatabasesClient{MockPostgreSQLDatabasesClient: mockClient} - wrapper := manual.NewDBforPostgreSQLDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with no query parts - should return error before calling ListByServer _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) @@ -263,7 +263,7 @@ func TestDBforPostgreSQLDatabase(t *testing.T) { armpostgresqlflexibleservers.DatabasesClientGetResponse{}, expectedErr) testClient := &testPostgreSQLDatabasesClient{MockPostgreSQLDatabasesClient: mockClient} - wrapper := manual.NewDBforPostgreSQLDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := shared.CompositeLookupKey(serverName, "nonexistent-database") @@ -283,7 +283,7 @@ func TestDBforPostgreSQLDatabase(t *testing.T) { pager: errorPager, } - wrapper := manual.NewDBforPostgreSQLDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) diff --git a/sources/azure/manual/dbforpostgresql-flexible-server.go b/sources/azure/manual/dbforpostgresql-flexible-server.go index b02239e2..c77f080b 100644 --- a/sources/azure/manual/dbforpostgresql-flexible-server.go +++ b/sources/azure/manual/dbforpostgresql-flexible-server.go @@ -19,15 +19,14 @@ var DBforPostgreSQLFlexibleServerLookupByName = shared.NewItemTypeLookup("name", type dbforPostgreSQLFlexibleServerWrapper struct { client clients.PostgreSQLFlexibleServersClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewDBforPostgreSQLFlexibleServer(client clients.PostgreSQLFlexibleServersClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewDBforPostgreSQLFlexibleServer(client clients.PostgreSQLFlexibleServersClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &dbforPostgreSQLFlexibleServerWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, 
sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, azureshared.DBforPostgreSQLFlexibleServer, ), @@ -44,11 +43,11 @@ func (s dbforPostgreSQLFlexibleServerWrapper) Get(ctx context.Context, scope str return nil, azureshared.QueryError(errors.New("serverName is empty"), scope, s.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, serverName, nil) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, serverName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } @@ -58,11 +57,11 @@ func (s dbforPostgreSQLFlexibleServerWrapper) Get(ctx context.Context, scope str // ref: https://learn.microsoft.com/en-us/rest/api/postgresql/servers/list-by-resource-group?view=rest-postgresql-2025-08-01&tabs=HTTP func (s dbforPostgreSQLFlexibleServerWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.ListByResourceGroup(ctx, resourceGroup, nil) + pager := s.client.ListByResourceGroup(ctx, rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { page, err := pager.NextPage(ctx) diff --git a/sources/azure/manual/dbforpostgresql-flexible-server_test.go b/sources/azure/manual/dbforpostgresql-flexible-server_test.go index a6b92348..600ac3b6 100644 --- a/sources/azure/manual/dbforpostgresql-flexible-server_test.go +++ b/sources/azure/manual/dbforpostgresql-flexible-server_test.go @@ -80,7 +80,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { }, nil) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -229,7 +229,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { }, nil) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -292,7 +292,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { }, nil) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, 
sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -332,7 +332,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { mockClient := mocks.NewMockPostgreSQLFlexibleServersClient(ctrl) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty query @@ -362,7 +362,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { pager: mockPager, } - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -415,7 +415,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { pager: mockPager, } - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -446,7 +446,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { armpostgresqlflexibleservers.ServersClientGetResponse{}, expectedErr) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-server", true) @@ -465,7 +465,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { pager: errorPager, } - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -502,7 +502,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { }, nil) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -610,7 +610,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { }, nil) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - 
wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], replicaServerName, true) @@ -648,7 +648,7 @@ func TestDBforPostgreSQLFlexibleServer(t *testing.T) { mockClient := mocks.NewMockPostgreSQLFlexibleServersClient(ctrl) testClient := &testPostgreSQLFlexibleServersClient{MockPostgreSQLFlexibleServersClient: mockClient} - wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDBforPostgreSQLFlexibleServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) potentialLinks := wrapper.PotentialLinks() expectedLinks := map[shared.ItemType]bool{ diff --git a/sources/azure/manual/documentdb-database-accounts.go b/sources/azure/manual/documentdb-database-accounts.go index 79080629..51afb676 100644 --- a/sources/azure/manual/documentdb-database-accounts.go +++ b/sources/azure/manual/documentdb-database-accounts.go @@ -18,15 +18,14 @@ var DocumentDBDatabaseAccountsLookupByName = shared.NewItemTypeLookup("name", az type documentDBDatabaseAccountsWrapper struct { client clients.DocumentDBDatabaseAccountsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewDocumentDBDatabaseAccounts(client clients.DocumentDBDatabaseAccountsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewDocumentDBDatabaseAccounts(client clients.DocumentDBDatabaseAccountsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &documentDBDatabaseAccountsWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, azureshared.DocumentDBDatabaseAccounts, ), @@ -34,11 +33,11 @@ func NewDocumentDBDatabaseAccounts(client clients.DocumentDBDatabaseAccountsClie } func (s documentDBDatabaseAccountsWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.ListByResourceGroup(resourceGroup) + pager := s.client.ListByResourceGroup(rgScope.ResourceGroup) var items []*sdp.Item for pager.More() { @@ -78,11 +77,11 @@ func (s documentDBDatabaseAccountsWrapper) Get(ctx context.Context, scope string } } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, accountName) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, accountName) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } diff --git a/sources/azure/manual/documentdb-database-accounts_test.go b/sources/azure/manual/documentdb-database-accounts_test.go index e0525c57..e6905f10 100644 --- 
a/sources/azure/manual/documentdb-database-accounts_test.go +++ b/sources/azure/manual/documentdb-database-accounts_test.go @@ -80,7 +80,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { DatabaseAccountGetResults: *account, }, nil) - wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) @@ -203,7 +203,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockDocumentDBDatabaseAccountsClient(ctrl) - wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty name @@ -227,7 +227,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { DatabaseAccountGetResults: *account, }, nil) - wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) @@ -245,7 +245,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { DatabaseAccountGetResults: *account, }, nil) - wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) @@ -281,7 +281,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { pager: mockPager, } - wrapper := manual.NewDocumentDBDatabaseAccounts(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewDocumentDBDatabaseAccounts(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -316,7 +316,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-account").Return( armcosmos.DatabaseAccountsClientGetResponse{}, expectedErr) - wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-account", true) @@ -336,7 +336,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(resourceGroup).Return(errorPager) - wrapper := manual.NewDocumentDBDatabaseAccounts(testClient, subscriptionID, resourceGroup) + wrapper := 
manual.NewDocumentDBDatabaseAccounts(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -360,7 +360,7 @@ func TestDocumentDBDatabaseAccounts(t *testing.T) { DatabaseAccountGetResults: *account, }, nil) - wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewDocumentDBDatabaseAccounts(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) diff --git a/sources/azure/manual/keyvault-managed-hsm.go b/sources/azure/manual/keyvault-managed-hsm.go index 6b0ffaaf..12be8064 100644 --- a/sources/azure/manual/keyvault-managed-hsm.go +++ b/sources/azure/manual/keyvault-managed-hsm.go @@ -21,15 +21,14 @@ var KeyVaultManagedHSMsLookupByName = shared.NewItemTypeLookup("name", azureshar type keyvaultManagedHSMsWrapper struct { client clients.ManagedHSMsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewKeyVaultManagedHSM(client clients.ManagedHSMsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewKeyVaultManagedHSM(client clients.ManagedHSMsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &keyvaultManagedHSMsWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY, azureshared.KeyVaultManagedHSM, ), @@ -38,11 +37,11 @@ func NewKeyVaultManagedHSM(client clients.ManagedHSMsClient, subscriptionID, res // ref: https://learn.microsoft.com/en-us/rest/api/keyvault/managedhsm/managed-hsms/list-by-resource-group?view=rest-keyvault-managedhsm-2024-11-01&tabs=HTTP func (k keyvaultManagedHSMsWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = k.ResourceGroup() + rgScope, err := k.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, k.Type()) } - pager := k.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := k.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -68,11 +67,12 @@ func (k keyvaultManagedHSMsWrapper) List(ctx context.Context, scope string) ([]* } func (k keyvaultManagedHSMsWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = k.ResourceGroup() + rgScope, err := k.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, k.Type())) + return } - pager := k.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := k.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) @@ -327,11 +327,11 @@ func (k keyvaultManagedHSMsWrapper) Get(ctx context.Context, scope string, query return nil, 
azureshared.QueryError(errors.New("name cannot be empty"), scope, k.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = k.ResourceGroup() + rgScope, err := k.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, k.Type()) } - resp, err := k.client.Get(ctx, resourceGroup, name, nil) + resp, err := k.client.Get(ctx, rgScope.ResourceGroup, name, nil) if err != nil { return nil, azureshared.QueryError(err, scope, k.Type()) } diff --git a/sources/azure/manual/keyvault-managed-hsm_test.go b/sources/azure/manual/keyvault-managed-hsm_test.go index bb3d81ea..056613cc 100644 --- a/sources/azure/manual/keyvault-managed-hsm_test.go +++ b/sources/azure/manual/keyvault-managed-hsm_test.go @@ -82,7 +82,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { ManagedHsm: *hsm, }, nil) - wrapper := manual.NewKeyVaultManagedHSM(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], hsmName, true) @@ -249,7 +249,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockManagedHSMsClient(ctrl) - wrapper := manual.NewKeyVaultManagedHSM(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty name @@ -273,7 +273,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { ManagedHsm: *hsm, }, nil) - wrapper := manual.NewKeyVaultManagedHSM(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], hsmName, true) @@ -291,7 +291,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { ManagedHsm: *hsm, }, nil) - wrapper := manual.NewKeyVaultManagedHSM(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], hsmName, true) @@ -327,7 +327,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { pager: mockPager, } - wrapper := manual.NewKeyVaultManagedHSM(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -366,7 +366,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { pager: errorPager, } - wrapper := manual.NewKeyVaultManagedHSM(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := 
adapter.(discovery.ListableAdapter) @@ -408,7 +408,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { pager: mockPager, } - wrapper := manual.NewKeyVaultManagedHSM(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -449,7 +449,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { pager: mockPager, } - wrapper := manual.NewKeyVaultManagedHSM(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -513,7 +513,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { pager: errorPager, } - wrapper := manual.NewKeyVaultManagedHSM(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) var errs []error @@ -563,7 +563,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { pager: mockPager, } - wrapper := manual.NewKeyVaultManagedHSM(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -615,7 +615,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { armkeyvault.ManagedHsmsClientGetResponse{}, errors.New("client error")) - wrapper := manual.NewKeyVaultManagedHSM(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], hsmName, true) @@ -634,7 +634,7 @@ func TestKeyVaultManagedHSM(t *testing.T) { ManagedHsm: *hsm, }, nil) - wrapper := manual.NewKeyVaultManagedHSM(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultManagedHSM(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], hsmName, true) diff --git a/sources/azure/manual/keyvault-secret.go b/sources/azure/manual/keyvault-secret.go index 27004cd9..b0aba42c 100644 --- a/sources/azure/manual/keyvault-secret.go +++ b/sources/azure/manual/keyvault-secret.go @@ -18,15 +18,14 @@ var KeyVaultSecretLookupByName = shared.NewItemTypeLookup("name", azureshared.Ke type keyvaultSecretWrapper struct { client clients.SecretsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewKeyVaultSecret(client clients.SecretsClient, subscriptionID, resourceGroup string) sources.SearchableWrapper { +func NewKeyVaultSecret(client clients.SecretsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { return &keyvaultSecretWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, 
+ MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY, azureshared.KeyVaultSecret, ), @@ -48,11 +47,11 @@ func (k keyvaultSecretWrapper) Get(ctx context.Context, scope string, queryParts return nil, azureshared.QueryError(errors.New("secretName cannot be empty"), scope, k.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = k.ResourceGroup() + rgScope, err := k.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, k.Type()) } - resp, err := k.client.Get(ctx, resourceGroup, vaultName, secretName, nil) + resp, err := k.client.Get(ctx, rgScope.ResourceGroup, vaultName, secretName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, k.Type()) } @@ -71,11 +70,11 @@ func (k keyvaultSecretWrapper) Search(ctx context.Context, scope string, queryPa return nil, azureshared.QueryError(errors.New("vaultName cannot be empty"), scope, k.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = k.ResourceGroup() + rgScope, err := k.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, k.Type()) } - pager := k.client.NewListPager(resourceGroup, vaultName, nil) + pager := k.client.NewListPager(rgScope.ResourceGroup, vaultName, nil) var items []*sdp.Item for pager.More() { diff --git a/sources/azure/manual/keyvault-secret_test.go b/sources/azure/manual/keyvault-secret_test.go index c89720c0..65a69ae4 100644 --- a/sources/azure/manual/keyvault-secret_test.go +++ b/sources/azure/manual/keyvault-secret_test.go @@ -83,7 +83,7 @@ func TestKeyVaultSecret(t *testing.T) { Secret: *secret, }, nil) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Get requires vaultName and secretName as query parts @@ -159,7 +159,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (only vault name) @@ -172,7 +172,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("Get_EmptyVaultName", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty vault name @@ -186,7 +186,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("Get_EmptySecretName", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, 
resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty secret name @@ -208,7 +208,7 @@ func TestKeyVaultSecret(t *testing.T) { Secret: *secret, }, nil) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := vaultName + shared.QuerySeparator + secretName @@ -227,7 +227,7 @@ func TestKeyVaultSecret(t *testing.T) { Secret: *secret, }, nil) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := vaultName + shared.QuerySeparator + secretName @@ -279,7 +279,7 @@ func TestKeyVaultSecret(t *testing.T) { pager: mockPager, } - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -311,7 +311,7 @@ func TestKeyVaultSecret(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) testClient := &testSecretsClient{MockSecretsClient: mockClient} - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with no query parts - should return error before calling List _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) @@ -324,7 +324,7 @@ func TestKeyVaultSecret(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) testClient := &testSecretsClient{MockSecretsClient: mockClient} - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with empty vault name _, qErr := wrapper.Search(ctx, "") @@ -365,7 +365,7 @@ func TestKeyVaultSecret(t *testing.T) { pager: mockPager, } - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -396,7 +396,7 @@ func TestKeyVaultSecret(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, vaultName, "nonexistent-secret", nil).Return( armkeyvault.SecretsClientGetResponse{}, expectedErr) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := vaultName + shared.QuerySeparator + "nonexistent-secret" @@ -418,7 +418,7 @@ func TestKeyVaultSecret(t *testing.T) { pager: errorPager, } - 
wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -437,7 +437,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) testClient := &testSecretsClient{MockSecretsClient: mockClient} - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements SearchableWrapper (it's returned as this type) if wrapper == nil { @@ -455,7 +455,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) testClient := &testSecretsClient{MockSecretsClient: mockClient} - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) links := wrapper.PotentialLinks() if len(links) == 0 { @@ -478,7 +478,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("TerraformMappings", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) testClient := &testSecretsClient{MockSecretsClient: mockClient} - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) mappings := wrapper.TerraformMappings() if len(mappings) == 0 { @@ -509,7 +509,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("IAMPermissions", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) testClient := &testSecretsClient{MockSecretsClient: mockClient} - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) permissions := wrapper.IAMPermissions() if len(permissions) == 0 { @@ -532,7 +532,7 @@ func TestKeyVaultSecret(t *testing.T) { t.Run("PredefinedRole", func(t *testing.T) { mockClient := mocks.NewMockSecretsClient(ctrl) testClient := &testSecretsClient{MockSecretsClient: mockClient} - wrapper := manual.NewKeyVaultSecret(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // PredefinedRole is available on the wrapper, not the interface // Use type assertion to access the concrete type @@ -562,7 +562,7 @@ func TestKeyVaultSecret(t *testing.T) { Secret: *secret, }, nil) - wrapper := manual.NewKeyVaultSecret(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultSecret(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := vaultName + shared.QuerySeparator + secretName diff --git a/sources/azure/manual/keyvault-vault.go b/sources/azure/manual/keyvault-vault.go index 3e1764d7..5116777c 
100644
--- a/sources/azure/manual/keyvault-vault.go
+++ b/sources/azure/manual/keyvault-vault.go
@@ -6,7 +6,9 @@ import (
     "fmt"
 
     "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault/v2"
+    "github.com/overmindtech/cli/discovery"
     "github.com/overmindtech/cli/sdp-go"
+    "github.com/overmindtech/cli/sdpcache"
     "github.com/overmindtech/cli/sources"
     "github.com/overmindtech/cli/sources/azure/clients"
     azureshared "github.com/overmindtech/cli/sources/azure/shared"
@@ -19,15 +21,14 @@ var KeyVaultVaultLookupByName = shared.NewItemTypeLookup("name", azureshared.Key
 
 type keyvaultVaultWrapper struct {
     client clients.VaultsClient
-    *azureshared.ResourceGroupBase
+    *azureshared.MultiResourceGroupBase
 }
 
-func NewKeyVaultVault(client clients.VaultsClient, subscriptionID, resourceGroup string) sources.ListableWrapper {
+func NewKeyVaultVault(client clients.VaultsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper {
     return &keyvaultVaultWrapper{
         client: client,
-        ResourceGroupBase: azureshared.NewResourceGroupBase(
-            subscriptionID,
-            resourceGroup,
+        MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase(
+            resourceGroupScopes,
             sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY,
             azureshared.KeyVaultVault,
         ),
@@ -35,11 +36,11 @@ func NewKeyVaultVault(client clients.VaultsClient, subscriptionID, resourceGroup
 }
 
 func (k keyvaultVaultWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) {
-    resourceGroup := azureshared.ResourceGroupFromScope(scope)
-    if resourceGroup == "" {
-        resourceGroup = k.ResourceGroup()
+    rgScope, err := k.ResourceGroupScopeFromScope(scope)
+    if err != nil {
+        return nil, azureshared.QueryError(err, scope, k.Type())
     }
-    pager := k.client.NewListByResourceGroupPager(resourceGroup, nil)
+    pager := k.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil)
 
     var items []*sdp.Item
     for pager.More() {
@@ -64,6 +65,37 @@ func (k keyvaultVaultWrapper) List(ctx context.Context, scope string) ([]*sdp.It
     return items, nil
 }
+
+func (k keyvaultVaultWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) {
+    rgScope, err := k.ResourceGroupScopeFromScope(scope)
+    if err != nil {
+        stream.SendError(azureshared.QueryError(err, scope, k.Type()))
+        return
+    }
+    pager := k.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil)
+
+    for pager.More() {
+        page, err := pager.NextPage(ctx)
+        if err != nil {
+            stream.SendError(azureshared.QueryError(err, scope, k.Type()))
+            return
+        }
+
+        for _, vault := range page.Value {
+            if vault.Name == nil {
+                continue
+            }
+            item, sdpErr := k.azureKeyVaultToSDPItem(vault, scope)
+            if sdpErr != nil {
+                stream.SendError(sdpErr)
+                continue
+            }
+            cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey)
+            stream.SendItem(item)
+        }
+    }
+}
+
 func (k keyvaultVaultWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) {
     if len(queryParts) < 1 {
         return nil, azureshared.QueryError(errors.New("Get requires 1 query part: vaultName"), scope, k.Type())
     }
@@ -74,11 +106,11 @@ func (k keyvaultVaultWrapper) Get(ctx context.Context, scope string, queryParts
         return nil, azureshared.QueryError(errors.New("vaultName cannot be empty"), scope, k.Type())
     }
 
-    resourceGroup := azureshared.ResourceGroupFromScope(scope)
-    if resourceGroup == "" {
-        resourceGroup = k.ResourceGroup()
+    rgScope, err := k.ResourceGroupScopeFromScope(scope)
+    if err != nil {
+        return nil, 
azureshared.QueryError(err, scope, k.Type()) } - resp, err := k.client.Get(ctx, resourceGroup, vaultName, nil) + resp, err := k.client.Get(ctx, rgScope.ResourceGroup, vaultName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, k.Type()) } diff --git a/sources/azure/manual/keyvault-vault_test.go b/sources/azure/manual/keyvault-vault_test.go index 2decc20b..ce440e9c 100644 --- a/sources/azure/manual/keyvault-vault_test.go +++ b/sources/azure/manual/keyvault-vault_test.go @@ -81,7 +81,7 @@ func TestKeyVaultVault(t *testing.T) { Vault: *vault, }, nil) - wrapper := manual.NewKeyVaultVault(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vaultName, true) @@ -204,7 +204,7 @@ func TestKeyVaultVault(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockVaultsClient(ctrl) - wrapper := manual.NewKeyVaultVault(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty name @@ -228,7 +228,7 @@ func TestKeyVaultVault(t *testing.T) { Vault: *vault, }, nil) - wrapper := manual.NewKeyVaultVault(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vaultName, true) @@ -246,7 +246,7 @@ func TestKeyVaultVault(t *testing.T) { Vault: *vault, }, nil) - wrapper := manual.NewKeyVaultVault(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vaultName, true) @@ -282,7 +282,7 @@ func TestKeyVaultVault(t *testing.T) { pager: mockPager, } - wrapper := manual.NewKeyVaultVault(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -321,7 +321,7 @@ func TestKeyVaultVault(t *testing.T) { pager: errorPager, } - wrapper := manual.NewKeyVaultVault(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -341,7 +341,7 @@ func TestKeyVaultVault(t *testing.T) { armkeyvault.VaultsClientGetResponse{}, errors.New("client error")) - wrapper := manual.NewKeyVaultVault(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := 
sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vaultName, true) @@ -360,7 +360,7 @@ func TestKeyVaultVault(t *testing.T) { Vault: *vault, }, nil) - wrapper := manual.NewKeyVaultVault(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vaultName, true) @@ -388,7 +388,7 @@ func TestKeyVaultVault(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockVaultsClient(ctrl) - wrapper := manual.NewKeyVaultVault(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewKeyVaultVault(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) links := wrapper.PotentialLinks() if len(links) == 0 { diff --git a/sources/azure/manual/managedidentity-user-assigned-identity.go b/sources/azure/manual/managedidentity-user-assigned-identity.go index b7f231d0..c8e3258f 100644 --- a/sources/azure/manual/managedidentity-user-assigned-identity.go +++ b/sources/azure/manual/managedidentity-user-assigned-identity.go @@ -19,15 +19,14 @@ var ManagedIdentityUserAssignedIdentityLookupByName = shared.NewItemTypeLookup(" type managedIdentityUserAssignedIdentityWrapper struct { client clients.UserAssignedIdentitiesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewManagedIdentityUserAssignedIdentity(client clients.UserAssignedIdentitiesClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewManagedIdentityUserAssignedIdentity(client clients.UserAssignedIdentitiesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &managedIdentityUserAssignedIdentityWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY, azureshared.ManagedIdentityUserAssignedIdentity, ), @@ -35,11 +34,11 @@ func NewManagedIdentityUserAssignedIdentity(client clients.UserAssignedIdentitie } func (m managedIdentityUserAssignedIdentityWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = m.ResourceGroup() + rgScope, err := m.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) } - pager := m.client.ListByResourceGroup(resourceGroup, nil) + pager := m.client.ListByResourceGroup(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -62,11 +61,12 @@ func (m managedIdentityUserAssignedIdentityWrapper) List(ctx context.Context, sc } func (m managedIdentityUserAssignedIdentityWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = m.ResourceGroup() + rgScope, err := m.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, m.Type())) + return } - pager := m.client.ListByResourceGroup(resourceGroup, nil) + pager := 
m.client.ListByResourceGroup(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -136,11 +136,11 @@ func (m managedIdentityUserAssignedIdentityWrapper) Get(ctx context.Context, sco if name == "" { return nil, azureshared.QueryError(errors.New("user assigned identity name cannot be empty"), scope, m.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = m.ResourceGroup() + rgScope, err := m.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) } - identity, err := m.client.Get(ctx, resourceGroup, name, nil) + identity, err := m.client.Get(ctx, rgScope.ResourceGroup, name, nil) if err != nil { return nil, azureshared.QueryError(err, scope, m.Type()) } diff --git a/sources/azure/manual/managedidentity-user-assigned-identity_test.go b/sources/azure/manual/managedidentity-user-assigned-identity_test.go index a1813d84..67716545 100644 --- a/sources/azure/manual/managedidentity-user-assigned-identity_test.go +++ b/sources/azure/manual/managedidentity-user-assigned-identity_test.go @@ -39,7 +39,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { Identity: *identity, }, nil) - wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], identityName, true) @@ -85,7 +85,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockUserAssignedIdentitiesClient(ctrl) - wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty name @@ -104,7 +104,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -157,7 +157,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -189,7 +189,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := 
manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -240,7 +240,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) var errs []error @@ -269,7 +269,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-identity", nil).Return( armmsi.UserAssignedIdentitiesClientGetResponse{}, expectedErr) - wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-identity", true) @@ -286,7 +286,7 @@ func TestManagedIdentityUserAssignedIdentity(t *testing.T) { mockClient.EXPECT().ListByResourceGroup(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewManagedIdentityUserAssignedIdentity(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) diff --git a/sources/azure/manual/network-application-gateway.go b/sources/azure/manual/network-application-gateway.go index c5955b16..76288cd9 100644 --- a/sources/azure/manual/network-application-gateway.go +++ b/sources/azure/manual/network-application-gateway.go @@ -20,15 +20,14 @@ var NetworkApplicationGatewayLookupByName = shared.NewItemTypeLookup("name", azu type networkApplicationGatewayWrapper struct { client clients.ApplicationGatewaysClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewNetworkApplicationGateway(client clients.ApplicationGatewaysClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewNetworkApplicationGateway(client clients.ApplicationGatewaysClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &networkApplicationGatewayWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, azureshared.NetworkApplicationGateway, ), @@ -36,11 +35,11 @@ func NewNetworkApplicationGateway(client clients.ApplicationGatewaysClient, subs } func (n networkApplicationGatewayWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = 
n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.List(resourceGroup, nil) + pager := n.client.List(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -63,11 +62,12 @@ func (n networkApplicationGatewayWrapper) List(ctx context.Context, scope string } func (n networkApplicationGatewayWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return } - pager := n.client.List(resourceGroup, nil) + pager := n.client.List(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -776,11 +776,11 @@ func (n networkApplicationGatewayWrapper) Get(ctx context.Context, scope string, if applicationGatewayName == "" { return nil, azureshared.QueryError(errors.New("application gateway name cannot be empty"), n.DefaultScope(), n.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - resp, err := n.client.Get(ctx, resourceGroup, applicationGatewayName, nil) + resp, err := n.client.Get(ctx, rgScope.ResourceGroup, applicationGatewayName, nil) if err != nil { return nil, azureshared.QueryError(err, n.DefaultScope(), n.Type()) } diff --git a/sources/azure/manual/network-application-gateway_test.go b/sources/azure/manual/network-application-gateway_test.go index 97de1bbe..53afb962 100644 --- a/sources/azure/manual/network-application-gateway_test.go +++ b/sources/azure/manual/network-application-gateway_test.go @@ -40,7 +40,7 @@ func TestNetworkApplicationGateway(t *testing.T) { ApplicationGateway: *applicationGateway, }, nil) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], agName, true) @@ -339,7 +339,7 @@ func TestNetworkApplicationGateway(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockApplicationGatewaysClient(ctrl) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test with wrong number of query parts - need to call through the wrapper directly _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0], "part1", "part2") @@ -351,7 +351,7 @@ func TestNetworkApplicationGateway(t *testing.T) { t.Run("Get_EmptyName", func(t *testing.T) { mockClient := mocks.NewMockApplicationGatewaysClient(ctrl) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty string name - validation happens before client.Get is called @@ -377,7 +377,7 @@ func TestNetworkApplicationGateway(t *testing.T) { ApplicationGateway: *applicationGateway, }, nil) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-ag", true) @@ -391,7 +391,7 @@ func TestNetworkApplicationGateway(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "test-ag", nil).Return( armnetwork.ApplicationGatewaysClientGetResponse{}, errors.New("not found")) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-ag", true) @@ -421,7 +421,7 @@ func TestNetworkApplicationGateway(t *testing.T) { mockClient.EXPECT().List(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -480,7 +480,7 @@ func TestNetworkApplicationGateway(t *testing.T) { mockClient.EXPECT().List(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -511,7 +511,7 @@ func TestNetworkApplicationGateway(t *testing.T) { mockClient.EXPECT().List(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -535,7 +535,7 @@ func TestNetworkApplicationGateway(t *testing.T) { ApplicationGateway: *applicationGateway, }, nil) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], agName, true) @@ -562,7 +562,7 @@ func TestNetworkApplicationGateway(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := 
mocks.NewMockApplicationGatewaysClient(ctrl) - wrapper := manual.NewNetworkApplicationGateway(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkApplicationGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Verify adapter implements ListableAdapter interface diff --git a/sources/azure/manual/network-load-balancer.go b/sources/azure/manual/network-load-balancer.go index 6f3789ce..187b50fe 100644 --- a/sources/azure/manual/network-load-balancer.go +++ b/sources/azure/manual/network-load-balancer.go @@ -13,6 +13,8 @@ import ( azureshared "github.com/overmindtech/cli/sources/azure/shared" "github.com/overmindtech/cli/sources/shared" "github.com/overmindtech/cli/sources/stdlib" + "github.com/overmindtech/cli/sdpcache" + "github.com/overmindtech/cli/discovery" ) var NetworkLoadBalancerLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkLoadBalancer) @@ -20,15 +22,14 @@ var NetworkLoadBalancerLookupByName = shared.NewItemTypeLookup("name", azureshar type networkLoadBalancerWrapper struct { client clients.LoadBalancersClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewNetworkLoadBalancer(client clients.LoadBalancersClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewNetworkLoadBalancer(client clients.LoadBalancersClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &networkLoadBalancerWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, azureshared.NetworkLoadBalancer, ), @@ -36,11 +37,11 @@ func NewNetworkLoadBalancer(client clients.LoadBalancersClient, subscriptionID, } func (n networkLoadBalancerWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.List(resourceGroup) + pager := n.client.List(rgScope.ResourceGroup) var items []*sdp.Item for pager.More() { @@ -65,6 +66,35 @@ func (n networkLoadBalancerWrapper) List(ctx context.Context, scope string) ([]* return items, nil } + +func (n networkLoadBalancerWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + pager := n.client.List(rgScope.ResourceGroup) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + for _, loadBalancer := range page.Value { + if loadBalancer.Name == nil { + continue + } + item, sdpErr := n.azureLoadBalancerToSDPItem(loadBalancer, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + func (n networkLoadBalancerWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, 
*sdp.QueryError) { if len(queryParts) != 1 { return nil, azureshared.QueryError(errors.New("query must be a load balancer name"), scope, n.Type()) @@ -75,11 +105,11 @@ func (n networkLoadBalancerWrapper) Get(ctx context.Context, scope string, query return nil, azureshared.QueryError(errors.New("load balancer name cannot be empty"), scope, n.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - resp, err := n.client.Get(ctx, resourceGroup, loadBalancerName) + resp, err := n.client.Get(ctx, rgScope.ResourceGroup, loadBalancerName) if err != nil { return nil, azureshared.QueryError(err, scope, n.Type()) } diff --git a/sources/azure/manual/network-load-balancer_test.go b/sources/azure/manual/network-load-balancer_test.go index f4bda115..92d6b5cc 100644 --- a/sources/azure/manual/network-load-balancer_test.go +++ b/sources/azure/manual/network-load-balancer_test.go @@ -39,7 +39,7 @@ func TestNetworkLoadBalancer(t *testing.T) { LoadBalancer: *loadBalancer, }, nil) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], lbName, true) @@ -195,7 +195,7 @@ func TestNetworkLoadBalancer(t *testing.T) { t.Run("Get_EmptyName", func(t *testing.T) { mockClient := mocks.NewMockLoadBalancersClient(ctrl) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty string name - the wrapper validates this before calling the client @@ -227,7 +227,7 @@ func TestNetworkLoadBalancer(t *testing.T) { mockClient.EXPECT().List(resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -288,7 +288,7 @@ func TestNetworkLoadBalancer(t *testing.T) { mockClient.EXPECT().List(resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -318,7 +318,7 @@ func TestNetworkLoadBalancer(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-lb").Return( armnetwork.LoadBalancersClientGetResponse{}, expectedErr) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := 
sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-lb", true) @@ -342,7 +342,7 @@ func TestNetworkLoadBalancer(t *testing.T) { mockClient.EXPECT().List(resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -358,7 +358,7 @@ func TestNetworkLoadBalancer(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockLoadBalancersClient(ctrl) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements ListableWrapper interface var _ = wrapper @@ -457,7 +457,7 @@ func TestNetworkLoadBalancer(t *testing.T) { LoadBalancer: *loadBalancer, }, nil) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], lbName, true) @@ -493,7 +493,7 @@ func TestNetworkLoadBalancer(t *testing.T) { LoadBalancer: *loadBalancer, }, nil) - wrapper := manual.NewNetworkLoadBalancer(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkLoadBalancer(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], lbName, true) diff --git a/sources/azure/manual/network-network-interface.go b/sources/azure/manual/network-network-interface.go index 273740a5..1e8fe0f9 100644 --- a/sources/azure/manual/network-network-interface.go +++ b/sources/azure/manual/network-network-interface.go @@ -11,6 +11,8 @@ import ( azureshared "github.com/overmindtech/cli/sources/azure/shared" "github.com/overmindtech/cli/sources/shared" "github.com/overmindtech/cli/sources/stdlib" + "github.com/overmindtech/cli/sdpcache" + "github.com/overmindtech/cli/discovery" ) var NetworkNetworkInterfaceLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkNetworkInterface) @@ -18,15 +20,14 @@ var NetworkNetworkInterfaceLookupByName = shared.NewItemTypeLookup("name", azure type networkNetworkInterfaceWrapper struct { client clients.NetworkInterfacesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewNetworkNetworkInterface(client clients.NetworkInterfacesClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewNetworkNetworkInterface(client clients.NetworkInterfacesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &networkNetworkInterfaceWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, 
azureshared.NetworkNetworkInterface, ), @@ -34,11 +35,11 @@ func NewNetworkNetworkInterface(client clients.NetworkInterfacesClient, subscrip } func (n networkNetworkInterfaceWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.List(ctx, resourceGroup) + pager := n.client.List(ctx, rgScope.ResourceGroup) var items []*sdp.Item for pager.More() { @@ -59,6 +60,34 @@ func (n networkNetworkInterfaceWrapper) List(ctx context.Context, scope string) return items, nil } + +func (n networkNetworkInterfaceWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + pager := n.client.List(ctx, rgScope.ResourceGroup) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + for _, networkInterface := range page.Value { + if networkInterface.Name == nil { + continue + } + item, sdpErr := n.azureNetworkInterfaceToSDPItem(networkInterface) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} // reference: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/network-interfaces/get?view=rest-virtualnetwork-2025-03-01&tabs=HTTP#response func (n networkNetworkInterfaceWrapper) azureNetworkInterfaceToSDPItem(networkInterface *armnetwork.Interface) (*sdp.Item, *sdp.QueryError) { if networkInterface.Name == nil { @@ -500,11 +529,11 @@ func (n networkNetworkInterfaceWrapper) Get(ctx context.Context, scope string, q } networkInterfaceName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - networkInterface, err := n.client.Get(ctx, resourceGroup, networkInterfaceName) + networkInterface, err := n.client.Get(ctx, rgScope.ResourceGroup, networkInterfaceName) if err != nil { return nil, azureshared.QueryError(err, n.DefaultScope(), n.Type()) } diff --git a/sources/azure/manual/network-network-interface_test.go b/sources/azure/manual/network-network-interface_test.go index 4fc273bf..d491ac3e 100644 --- a/sources/azure/manual/network-network-interface_test.go +++ b/sources/azure/manual/network-network-interface_test.go @@ -39,7 +39,7 @@ func TestNetworkNetworkInterface(t *testing.T) { Interface: *nic, }, nil) - wrapper := manual.NewNetworkNetworkInterface(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], nicName, true) @@ -122,7 +122,7 @@ func TestNetworkNetworkInterface(t *testing.T) { Interface: *nicWithDNS, }, nil) - wrapper := manual.NewNetworkNetworkInterface(mockClient, 
subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], nicName, true) @@ -200,7 +200,7 @@ func TestNetworkNetworkInterface(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockNetworkInterfacesClient(ctrl) - wrapper := manual.NewNetworkNetworkInterface(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty string name - Get will still be called with empty string @@ -235,7 +235,7 @@ func TestNetworkNetworkInterface(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkNetworkInterface(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -306,7 +306,7 @@ func TestNetworkNetworkInterface(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkNetworkInterface(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -328,7 +328,7 @@ func TestNetworkNetworkInterface(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-nic").Return( armnetwork.InterfacesClientGetResponse{}, expectedErr) - wrapper := manual.NewNetworkNetworkInterface(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-nic", true) @@ -352,7 +352,7 @@ func TestNetworkNetworkInterface(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkNetworkInterface(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -368,7 +368,7 @@ func TestNetworkNetworkInterface(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockNetworkInterfacesClient(ctrl) - wrapper := manual.NewNetworkNetworkInterface(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkInterface(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements ListableWrapper interface var _ = wrapper diff --git 
a/sources/azure/manual/network-network-security-group.go b/sources/azure/manual/network-network-security-group.go index e742affa..30e19904 100644 --- a/sources/azure/manual/network-network-security-group.go +++ b/sources/azure/manual/network-network-security-group.go @@ -35,15 +35,14 @@ func appendIPOrCIDRLinkIfValid(queries *[]*sdp.LinkedItemQuery, prefix string) { type networkNetworkSecurityGroupWrapper struct { client clients.NetworkSecurityGroupsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewNetworkNetworkSecurityGroup(client clients.NetworkSecurityGroupsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewNetworkNetworkSecurityGroup(client clients.NetworkSecurityGroupsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &networkNetworkSecurityGroupWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, azureshared.NetworkNetworkSecurityGroup, ), @@ -52,11 +51,11 @@ func NewNetworkNetworkSecurityGroup(client clients.NetworkSecurityGroupsClient, // reference: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/network-security-groups/list?view=rest-virtualnetwork-2025-03-01&tabs=HTTP func (n networkNetworkSecurityGroupWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.List(ctx, resourceGroup, nil) + pager := n.client.List(ctx, rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -79,11 +78,12 @@ func (n networkNetworkSecurityGroupWrapper) List(ctx context.Context, scope stri } func (n networkNetworkSecurityGroupWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return } - pager := n.client.List(ctx, resourceGroup, nil) + pager := n.client.List(ctx, rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -112,11 +112,11 @@ func (n networkNetworkSecurityGroupWrapper) Get(ctx context.Context, scope strin } networkSecurityGroupName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - networkSecurityGroup, err := n.client.Get(ctx, resourceGroup, networkSecurityGroupName, nil) + networkSecurityGroup, err := n.client.Get(ctx, rgScope.ResourceGroup, networkSecurityGroupName, nil) if err != nil { return nil, azureshared.QueryError(err, n.DefaultScope(), n.Type()) } diff --git a/sources/azure/manual/network-network-security-group_test.go b/sources/azure/manual/network-network-security-group_test.go index 84f09695..559c196a 100644 --- 
a/sources/azure/manual/network-network-security-group_test.go +++ b/sources/azure/manual/network-network-security-group_test.go @@ -39,7 +39,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { SecurityGroup: *nsg, }, nil) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], nsgName, true) @@ -151,7 +151,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockNetworkSecurityGroupsClient(ctrl) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty string name - Get will still be called with empty string @@ -180,7 +180,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { SecurityGroup: *nsg, }, nil) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-nsg", true) @@ -210,7 +210,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -270,7 +270,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -304,7 +304,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -325,7 +325,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-nsg", nil).Return( armnetwork.SecurityGroupsClientGetResponse{}, expectedErr) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, 
subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-nsg", true) @@ -366,7 +366,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { SecurityGroup: *nsg, }, nil) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], nsgName, true) @@ -407,7 +407,7 @@ func TestNetworkNetworkSecurityGroup(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockNetworkSecurityGroupsClient(ctrl) - wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkNetworkSecurityGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements ListableWrapper interface var _ = wrapper diff --git a/sources/azure/manual/network-public-ip-address.go b/sources/azure/manual/network-public-ip-address.go index a4128d15..f06ed5ff 100644 --- a/sources/azure/manual/network-public-ip-address.go +++ b/sources/azure/manual/network-public-ip-address.go @@ -6,6 +6,8 @@ import ( "strings" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v8" + "github.com/overmindtech/cli/sdpcache" + "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" "github.com/overmindtech/cli/sources" "github.com/overmindtech/cli/sources/azure/clients" @@ -19,15 +21,14 @@ var NetworkPublicIPAddressLookupByName = shared.NewItemTypeLookup("name", azures type networkPublicIPAddressWrapper struct { client clients.PublicIPAddressesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewNetworkPublicIPAddress(client clients.PublicIPAddressesClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewNetworkPublicIPAddress(client clients.PublicIPAddressesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &networkPublicIPAddressWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, azureshared.NetworkPublicIPAddress, ), @@ -37,11 +38,11 @@ func NewNetworkPublicIPAddress(client clients.PublicIPAddressesClient, subscript // reference: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/public-ip-addresses/list?view=rest-virtualnetwork-2025-03-01&tabs=HTTP // GET https://management.azure.com/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Network/publicIPAddresses?api-version=2025-03-01 func (n networkPublicIPAddressWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, 
azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.List(ctx, resourceGroup) + pager := n.client.List(ctx, rgScope.ResourceGroup) var items []*sdp.Item for pager.More() { @@ -66,6 +67,33 @@ func (n networkPublicIPAddressWrapper) List(ctx context.Context, scope string) ( return items, nil } +func (n networkPublicIPAddressWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + pager := n.client.List(ctx, rgScope.ResourceGroup) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + for _, publicIPAddress := range page.Value { + if publicIPAddress.Name == nil { + continue + } + item, sdpErr := n.azurePublicIPAddressToSDPItem(publicIPAddress, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} // reference: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/public-ip-addresses/get?view=rest-virtualnetwork-2025-03-01&tabs=HTTP // GET https://management.azure.com/subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/Microsoft.Network/publicIPAddresses/{publicIpAddressName}?api-version=2025-03-01 func (n networkPublicIPAddressWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { @@ -75,11 +103,11 @@ func (n networkPublicIPAddressWrapper) Get(ctx context.Context, scope string, qu publicIPAddressName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - publicIPAddress, err := n.client.Get(ctx, resourceGroup, publicIPAddressName) + publicIPAddress, err := n.client.Get(ctx, rgScope.ResourceGroup, publicIPAddressName) if err != nil { return nil, azureshared.QueryError(err, scope, n.Type()) } diff --git a/sources/azure/manual/network-public-ip-address_test.go b/sources/azure/manual/network-public-ip-address_test.go index 2bf7c664..6da3eeff 100644 --- a/sources/azure/manual/network-public-ip-address_test.go +++ b/sources/azure/manual/network-public-ip-address_test.go @@ -39,7 +39,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { PublicIPAddress: *publicIP, }, nil) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], publicIPName, true) @@ -137,7 +137,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { PublicIPAddress: *publicIP, }, nil) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], publicIPName, true) @@ 
-170,7 +170,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { PublicIPAddress: *publicIP, }, nil) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], publicIPName, true) @@ -206,7 +206,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { PublicIPAddress: *publicIP, }, nil) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], publicIPName, true) @@ -234,7 +234,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockPublicIPAddressesClient(ctrl) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty name - Get will still be called with empty string @@ -269,7 +269,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -332,7 +332,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -361,7 +361,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-ip").Return( armnetwork.PublicIPAddressesClientGetResponse{}, expectedErr) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-ip", true) @@ -385,7 +385,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { mockClient.EXPECT().List(ctx, resourceGroup).Return(mockPager) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := 
sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -401,7 +401,7 @@ func TestNetworkPublicIPAddress(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockPublicIPAddressesClient(ctrl) - wrapper := manual.NewNetworkPublicIPAddress(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkPublicIPAddress(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements ListableWrapper interface var _ = wrapper diff --git a/sources/azure/manual/network-route-table.go b/sources/azure/manual/network-route-table.go index 644bb2a4..0194ef5a 100644 --- a/sources/azure/manual/network-route-table.go +++ b/sources/azure/manual/network-route-table.go @@ -11,6 +11,8 @@ import ( azureshared "github.com/overmindtech/cli/sources/azure/shared" "github.com/overmindtech/cli/sources/shared" "github.com/overmindtech/cli/sources/stdlib" + "github.com/overmindtech/cli/sdpcache" + "github.com/overmindtech/cli/discovery" ) var NetworkRouteTableLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkRouteTable) @@ -18,14 +20,13 @@ var NetworkRouteTableLookupByName = shared.NewItemTypeLookup("name", azureshared type networkRouteTableWrapper struct { client clients.RouteTablesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewNetworkRouteTable(client clients.RouteTablesClient, subscriptionID, resourceGroup string) *networkRouteTableWrapper { +func NewNetworkRouteTable(client clients.RouteTablesClient, resourceGroupScopes []azureshared.ResourceGroupScope) *networkRouteTableWrapper { return &networkRouteTableWrapper{ - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, azureshared.NetworkRouteTable, ), @@ -34,11 +35,11 @@ func NewNetworkRouteTable(client clients.RouteTablesClient, subscriptionID, reso } func (n networkRouteTableWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.List(resourceGroup, nil) + pager := n.client.List(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -60,6 +61,34 @@ func (n networkRouteTableWrapper) List(ctx context.Context, scope string) ([]*sd return items, nil } +func (n networkRouteTableWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + pager := n.client.List(rgScope.ResourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + for _, routeTable := range page.Value { + if routeTable.Name == nil { + continue + } + item, sdpErr := n.azureRouteTableToSDPItem(routeTable) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + 
stream.SendItem(item) + } + } +} + func (n networkRouteTableWrapper) azureRouteTableToSDPItem(routeTable *armnetwork.RouteTable) (*sdp.Item, *sdp.QueryError) { if routeTable.Name == nil { return nil, azureshared.QueryError(errors.New("route table name is nil"), n.DefaultScope(), n.Type()) @@ -172,11 +201,11 @@ func (n networkRouteTableWrapper) Get(ctx context.Context, scope string, queryPa if routeTableName == "" { return nil, azureshared.QueryError(errors.New("route table name is empty"), n.DefaultScope(), n.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - resp, err := n.client.Get(ctx, resourceGroup, routeTableName, nil) + resp, err := n.client.Get(ctx, rgScope.ResourceGroup, routeTableName, nil) if err != nil { return nil, azureshared.QueryError(err, n.DefaultScope(), n.Type()) } diff --git a/sources/azure/manual/network-route-table_test.go b/sources/azure/manual/network-route-table_test.go index e6dc9323..3dd23909 100644 --- a/sources/azure/manual/network-route-table_test.go +++ b/sources/azure/manual/network-route-table_test.go @@ -39,7 +39,7 @@ func TestNetworkRouteTable(t *testing.T) { RouteTable: *routeTable, }, nil) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], routeTableName, true) @@ -107,7 +107,7 @@ func TestNetworkRouteTable(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockRouteTablesClient(ctrl) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty string name - validation happens before client.Get is called @@ -133,7 +133,7 @@ func TestNetworkRouteTable(t *testing.T) { RouteTable: *routeTable, }, nil) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-route-table", true) @@ -163,7 +163,7 @@ func TestNetworkRouteTable(t *testing.T) { mockClient.EXPECT().List(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -223,7 +223,7 @@ func TestNetworkRouteTable(t *testing.T) { mockClient.EXPECT().List(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -257,7 +257,7 @@ func TestNetworkRouteTable(t *testing.T) { mockClient.EXPECT().List(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -278,7 +278,7 @@ func TestNetworkRouteTable(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-route-table", nil).Return( armnetwork.RouteTablesClientGetResponse{}, expectedErr) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-route-table", true) @@ -314,7 +314,7 @@ func TestNetworkRouteTable(t *testing.T) { RouteTable: *routeTable, }, nil) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], routeTableName, true) @@ -367,7 +367,7 @@ func TestNetworkRouteTable(t *testing.T) { RouteTable: *routeTable, }, nil) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], routeTableName, true) @@ -422,7 +422,7 @@ func TestNetworkRouteTable(t *testing.T) { RouteTable: *routeTable, }, nil) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], routeTableName, true) @@ -440,7 +440,7 @@ func TestNetworkRouteTable(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockRouteTablesClient(ctrl) - wrapper := manual.NewNetworkRouteTable(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkRouteTable(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements ListableWrapper interface var _ sources.ListableWrapper = wrapper diff --git a/sources/azure/manual/network-virtual-network.go b/sources/azure/manual/network-virtual-network.go index 61580796..4cc91fac 100644 --- a/sources/azure/manual/network-virtual-network.go +++ b/sources/azure/manual/network-virtual-network.go @@ -11,6 +11,8 @@ import ( azureshared 
"github.com/overmindtech/cli/sources/azure/shared" "github.com/overmindtech/cli/sources/shared" "github.com/overmindtech/cli/sources/stdlib" + "github.com/overmindtech/cli/sdpcache" + "github.com/overmindtech/cli/discovery" ) var NetworkVirtualNetworkLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkVirtualNetwork) @@ -18,16 +20,15 @@ var NetworkVirtualNetworkLookupByName = shared.NewItemTypeLookup("name", azuresh type networkVirtualNetworkWrapper struct { client clients.VirtualNetworksClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } // NewNetworkVirtualNetwork creates a new networkVirtualNetworkWrapper instance -func NewNetworkVirtualNetwork(client clients.VirtualNetworksClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewNetworkVirtualNetwork(client clients.VirtualNetworksClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &networkVirtualNetworkWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, azureshared.NetworkVirtualNetwork, ), @@ -35,11 +36,11 @@ func NewNetworkVirtualNetwork(client clients.VirtualNetworksClient, subscription } func (n networkVirtualNetworkWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.NewListPager(resourceGroup, nil) + pager := n.client.NewListPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -61,6 +62,35 @@ func (n networkVirtualNetworkWrapper) List(ctx context.Context, scope string) ([ return items, nil } +func (n networkVirtualNetworkWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + pager := n.client.NewListPager(rgScope.ResourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + + for _, network := range page.Value { + if network.Name == nil { + continue + } + item, sdpErr := n.azureVirtualNetworkToSDPItem(network, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + func (n networkVirtualNetworkWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { if len(queryParts) < 1 { return nil, &sdp.QueryError{ @@ -73,11 +103,11 @@ func (n networkVirtualNetworkWrapper) Get(ctx context.Context, scope string, que virtualNetworkName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - resp, err := n.client.Get(ctx, resourceGroup, virtualNetworkName, nil) + resp, err := n.client.Get(ctx, 
rgScope.ResourceGroup, virtualNetworkName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, n.Type()) } diff --git a/sources/azure/manual/network-virtual-network_test.go b/sources/azure/manual/network-virtual-network_test.go index b7c5bf15..e230a831 100644 --- a/sources/azure/manual/network-virtual-network_test.go +++ b/sources/azure/manual/network-virtual-network_test.go @@ -39,7 +39,7 @@ func TestNetworkVirtualNetwork(t *testing.T) { VirtualNetwork: *vnet, }, nil) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vnetName, true) @@ -103,7 +103,7 @@ func TestNetworkVirtualNetwork(t *testing.T) { VirtualNetwork: *vnet, }, nil) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], vnetName, true) @@ -171,7 +171,7 @@ func TestNetworkVirtualNetwork(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockVirtualNetworksClient(ctrl) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty string name - Get will still be called with empty string @@ -206,7 +206,7 @@ func TestNetworkVirtualNetwork(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -272,7 +272,7 @@ func TestNetworkVirtualNetwork(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -294,7 +294,7 @@ func TestNetworkVirtualNetwork(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-vnet", nil).Return( armnetwork.VirtualNetworksClientGetResponse{}, expectedErr) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-vnet", true) @@ -318,7 +318,7 @@ func 
TestNetworkVirtualNetwork(t *testing.T) { mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -334,7 +334,7 @@ func TestNetworkVirtualNetwork(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockVirtualNetworksClient(ctrl) - wrapper := manual.NewNetworkVirtualNetwork(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkVirtualNetwork(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements ListableWrapper interface var _ = wrapper diff --git a/sources/azure/manual/network-zone.go b/sources/azure/manual/network-zone.go index 46be346b..aaec32ff 100644 --- a/sources/azure/manual/network-zone.go +++ b/sources/azure/manual/network-zone.go @@ -20,15 +20,14 @@ var NetworkZoneLookupByName = shared.NewItemTypeLookup("name", azureshared.Netwo type networkZoneWrapper struct { client clients.ZonesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewNetworkZone(client clients.ZonesClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewNetworkZone(client clients.ZonesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &networkZoneWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, azureshared.NetworkZone, ), @@ -36,11 +35,11 @@ func NewNetworkZone(client clients.ZonesClient, subscriptionID, resourceGroup st } func (n networkZoneWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - pager := n.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := n.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -66,11 +65,12 @@ func (n networkZoneWrapper) List(ctx context.Context, scope string) ([]*sdp.Item // ref: https://learn.microsoft.com/en-us/rest/api/dns/zones/list-by-resource-group?view=rest-dns-2018-05-01&tabs=HTTP func (n networkZoneWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return } - pager := n.client.NewListByResourceGroupPager(resourceGroup, nil) + pager := n.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) if err != nil { @@ -245,11 +245,11 @@ func (n networkZoneWrapper) Get(ctx context.Context, scope string, 
queryParts .. } zoneName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = n.ResourceGroup() + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) } - zone, err := n.client.Get(ctx, resourceGroup, zoneName, nil) + zone, err := n.client.Get(ctx, rgScope.ResourceGroup, zoneName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, n.Type()) } diff --git a/sources/azure/manual/network-zone_test.go b/sources/azure/manual/network-zone_test.go index 36468110..e2c919de 100644 --- a/sources/azure/manual/network-zone_test.go +++ b/sources/azure/manual/network-zone_test.go @@ -40,7 +40,7 @@ func TestNetworkZone(t *testing.T) { Zone: *zone, }, nil) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], zoneName, true) @@ -141,7 +141,7 @@ func TestNetworkZone(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockZonesClient(ctrl) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with empty name - the client will be called but will return an error @@ -179,7 +179,7 @@ func TestNetworkZone(t *testing.T) { Zone: *zone, }, nil) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], zoneName, true) @@ -224,7 +224,7 @@ func TestNetworkZone(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -285,7 +285,7 @@ func TestNetworkZone(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -315,7 +315,7 @@ func TestNetworkZone(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-zone", nil).Return( armdns.ZonesClientGetResponse{}, expectedErr) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := 
sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-zone", true) @@ -339,7 +339,7 @@ func TestNetworkZone(t *testing.T) { mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -355,7 +355,7 @@ func TestNetworkZone(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockZonesClient(ctrl) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements ListableWrapper interface var _ = wrapper @@ -454,7 +454,7 @@ func TestNetworkZone(t *testing.T) { Zone: *zone, }, nil) - wrapper := manual.NewNetworkZone(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewNetworkZone(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], zoneName, true) diff --git a/sources/azure/manual/sql-database.go b/sources/azure/manual/sql-database.go index d073230f..34a1f32c 100644 --- a/sources/azure/manual/sql-database.go +++ b/sources/azure/manual/sql-database.go @@ -17,15 +17,14 @@ var SQLDatabaseLookupByName = shared.NewItemTypeLookup("name", azureshared.SQLDa type sqlDatabaseWrapper struct { client clients.SqlDatabasesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewSqlDatabase(client clients.SqlDatabasesClient, subscriptionID, resourceGroup string) sources.SearchableWrapper { +func NewSqlDatabase(client clients.SqlDatabasesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { return &sqlDatabaseWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, azureshared.SQLDatabase, ), @@ -44,11 +43,11 @@ func (s sqlDatabaseWrapper) Get(ctx context.Context, scope string, queryParts .. 
serverName := queryParts[0] databaseName := queryParts[1] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, serverName, databaseName) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, serverName, databaseName) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } @@ -405,11 +404,11 @@ func (s sqlDatabaseWrapper) Search(ctx context.Context, scope string, queryParts } serverName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.ListByServer(ctx, resourceGroup, serverName) + pager := s.client.ListByServer(ctx, rgScope.ResourceGroup, serverName) var items []*sdp.Item for pager.More() { diff --git a/sources/azure/manual/sql-database_test.go b/sources/azure/manual/sql-database_test.go index 48a61728..4e08fba1 100644 --- a/sources/azure/manual/sql-database_test.go +++ b/sources/azure/manual/sql-database_test.go @@ -80,7 +80,7 @@ func TestSqlDatabase(t *testing.T) { }, nil) testClient := &testSqlDatabasesClient{MockSqlDatabasesClient: mockClient} - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Get requires serverName and databaseName as query parts @@ -153,7 +153,7 @@ func TestSqlDatabase(t *testing.T) { }, nil) testClient := &testSqlDatabasesClient{MockSqlDatabasesClient: mockClient} - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := serverName + shared.QuerySeparator + databaseName @@ -207,7 +207,7 @@ func TestSqlDatabase(t *testing.T) { mockClient := mocks.NewMockSqlDatabasesClient(ctrl) testClient := &testSqlDatabasesClient{MockSqlDatabasesClient: mockClient} - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (only server name) @@ -237,7 +237,7 @@ func TestSqlDatabase(t *testing.T) { pager: mockPager, } - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -295,7 +295,7 @@ func TestSqlDatabase(t *testing.T) { pager: mockPager, } - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -322,7 +322,7 @@ func TestSqlDatabase(t *testing.T) { mockClient := mocks.NewMockSqlDatabasesClient(ctrl) testClient := &testSqlDatabasesClient{MockSqlDatabasesClient: mockClient} - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with no query parts - should return error before calling ListByServer _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) @@ -339,7 +339,7 @@ func TestSqlDatabase(t *testing.T) { armsql.DatabasesClientGetResponse{}, expectedErr) testClient := &testSqlDatabasesClient{MockSqlDatabasesClient: mockClient} - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := serverName + shared.QuerySeparator + "nonexistent-database" @@ -359,7 +359,7 @@ func TestSqlDatabase(t *testing.T) { pager: errorPager, } - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -377,7 +377,7 @@ func TestSqlDatabase(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockSqlDatabasesClient(ctrl) testClient := &testSqlDatabasesClient{MockSqlDatabasesClient: mockClient} - wrapper := manual.NewSqlDatabase(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlDatabase(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Cast to sources.Wrapper to access interface methods w := wrapper.(sources.Wrapper) diff --git a/sources/azure/manual/sql-server.go b/sources/azure/manual/sql-server.go index aea07b36..c6f7b3f0 100644 --- a/sources/azure/manual/sql-server.go +++ b/sources/azure/manual/sql-server.go @@ -17,12 +17,11 @@ import ( var SQLServerLookupByName = shared.NewItemTypeLookup("name", azureshared.SQLServer) -func NewSqlServer(client clients.SqlServersClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewSqlServer(client clients.SqlServersClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &sqlServerWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, azureshared.SQLServer, ), @@ -32,15 +31,15 @@ func NewSqlServer(client clients.SqlServersClient, subscriptionID, resourceGroup type sqlServerWrapper struct { client clients.SqlServersClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } func (s sqlServerWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if 
resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.ListByResourceGroup(ctx, resourceGroup, nil) + pager := s.client.ListByResourceGroup(ctx, rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -64,11 +63,12 @@ func (s sqlServerWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, } func (s sqlServerWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, s.Type())) + return } - pager := s.client.ListByResourceGroup(ctx, resourceGroup, nil) + pager := s.client.ListByResourceGroup(ctx, rgScope.ResourceGroup, nil) for pager.More() { page, err := pager.NextPage(ctx) @@ -100,11 +100,11 @@ func (s sqlServerWrapper) Get(ctx context.Context, scope string, queryParts ...s return nil, azureshared.QueryError(errors.New("serverName is empty"), scope, s.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, serverName, nil) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, serverName, nil) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } diff --git a/sources/azure/manual/sql-server_test.go b/sources/azure/manual/sql-server_test.go index a04e58ef..255126f6 100644 --- a/sources/azure/manual/sql-server_test.go +++ b/sources/azure/manual/sql-server_test.go @@ -81,7 +81,7 @@ func TestSqlServer(t *testing.T) { }, nil) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -411,7 +411,7 @@ func TestSqlServer(t *testing.T) { }, nil) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -471,7 +471,7 @@ func TestSqlServer(t *testing.T) { }, nil) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -537,7 +537,7 @@ func TestSqlServer(t *testing.T) { }, nil) testClient := 
&testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -603,7 +603,7 @@ func TestSqlServer(t *testing.T) { }, nil) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -649,7 +649,7 @@ func TestSqlServer(t *testing.T) { }, nil) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -688,7 +688,7 @@ func TestSqlServer(t *testing.T) { }, nil) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) @@ -720,7 +720,7 @@ func TestSqlServer(t *testing.T) { mockClient := mocks.NewMockSqlServersClient(ctrl) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (no server name) @@ -750,7 +750,7 @@ func TestSqlServer(t *testing.T) { pager: mockPager, } - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -811,7 +811,7 @@ func TestSqlServer(t *testing.T) { pager: mockPager, } - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -854,7 +854,7 @@ func TestSqlServer(t *testing.T) { pager: mockPager, } - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, 
resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) wg := &sync.WaitGroup{} @@ -905,7 +905,7 @@ func TestSqlServer(t *testing.T) { armsql.ServersClientGetResponse{}, expectedErr) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-server", true) @@ -924,7 +924,7 @@ func TestSqlServer(t *testing.T) { pager: errorPager, } - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -949,7 +949,7 @@ func TestSqlServer(t *testing.T) { pager: errorPager, } - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) var errs []error @@ -975,7 +975,7 @@ func TestSqlServer(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockSqlServersClient(ctrl) testClient := &testSqlServersClient{MockSqlServersClient: mockClient} - wrapper := manual.NewSqlServer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewSqlServer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Cast to sources.Wrapper to access interface methods w := wrapper.(sources.Wrapper) diff --git a/sources/azure/manual/storage-account.go b/sources/azure/manual/storage-account.go index f17ff6fa..b6465eb6 100644 --- a/sources/azure/manual/storage-account.go +++ b/sources/azure/manual/storage-account.go @@ -11,6 +11,8 @@ import ( azureshared "github.com/overmindtech/cli/sources/azure/shared" "github.com/overmindtech/cli/sources/shared" "github.com/overmindtech/cli/sources/stdlib" + "github.com/overmindtech/cli/discovery" + "github.com/overmindtech/cli/sdpcache" ) var StorageAccountLookupByName = shared.NewItemTypeLookup("name", azureshared.StorageAccount) @@ -18,15 +20,14 @@ var StorageAccountLookupByName = shared.NewItemTypeLookup("name", azureshared.St type storageAccountWrapper struct { client clients.StorageAccountsClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewStorageAccount(client clients.StorageAccountsClient, subscriptionID, resourceGroup string) sources.ListableWrapper { +func NewStorageAccount(client clients.StorageAccountsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { return &storageAccountWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, azureshared.StorageAccount, ), @@ -34,11 +35,11 @@ func NewStorageAccount(client clients.StorageAccountsClient, subscriptionID, res } func (s storageAccountWrapper) List(ctx context.Context, scope 
string) ([]*sdp.Item, *sdp.QueryError) { - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.List(resourceGroup) + pager := s.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) var items []*sdp.Item for pager.More() { @@ -63,6 +64,34 @@ func (s storageAccountWrapper) List(ctx context.Context, scope string) ([]*sdp.I return items, nil } +func (s storageAccountWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, s.Type())) + return + } + pager := s.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, s.Type())) + return + } + for _, account := range page.Value { + if account.Name == nil { + continue + } + item, sdpErr := s.azureStorageAccountToSDPItem(account, *account.Name, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } +} +} + func (s storageAccountWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { if len(queryParts) < 1 { return nil, &sdp.QueryError{ @@ -82,11 +111,11 @@ func (s storageAccountWrapper) Get(ctx context.Context, scope string, queryParts } } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, accountName) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, accountName) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } diff --git a/sources/azure/manual/storage-account_test.go b/sources/azure/manual/storage-account_test.go index 02696cca..4c595eed 100644 --- a/sources/azure/manual/storage-account_test.go +++ b/sources/azure/manual/storage-account_test.go @@ -37,7 +37,7 @@ func TestStorageAccount(t *testing.T) { Account: *account, }, nil) - wrapper := manual.NewStorageAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) @@ -171,7 +171,7 @@ func TestStorageAccount(t *testing.T) { t.Run("Get_InvalidQueryParts", func(t *testing.T) { mockClient := mocks.NewMockStorageAccountsClient(ctrl) - wrapper := manual.NewStorageAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (empty) @@ -200,9 +200,9 @@ func TestStorageAccount(t *testing.T) { mockPager.EXPECT().More().Return(false), ) - 
mockClient.EXPECT().List(resourceGroup).Return(mockPager) + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewStorageAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -263,9 +263,9 @@ func TestStorageAccount(t *testing.T) { mockPager.EXPECT().More().Return(false), ) - mockClient.EXPECT().List(resourceGroup).Return(mockPager) + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewStorageAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) @@ -299,7 +299,7 @@ func TestStorageAccount(t *testing.T) { mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-account").Return( armstorage.AccountsClientGetPropertiesResponse{}, expectedErr) - wrapper := manual.NewStorageAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-account", true) @@ -321,9 +321,9 @@ func TestStorageAccount(t *testing.T) { armstorage.AccountsClientListByResourceGroupResponse{}, expectedErr), ) - mockClient.EXPECT().List(resourceGroup).Return(mockPager) + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) - wrapper := manual.NewStorageAccount(mockClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageAccount(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) listable, ok := adapter.(discovery.ListableAdapter) diff --git a/sources/azure/manual/storage-blob-container.go b/sources/azure/manual/storage-blob-container.go index a70b7b6b..7c424c87 100644 --- a/sources/azure/manual/storage-blob-container.go +++ b/sources/azure/manual/storage-blob-container.go @@ -19,15 +19,14 @@ var StorageBlobContainerLookupByName = shared.NewItemTypeLookup("name", azuresha type storageBlobContainerWrapper struct { client clients.BlobContainersClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewStorageBlobContainer(client clients.BlobContainersClient, subscriptionID, resourceGroup string) sources.SearchableWrapper { +func NewStorageBlobContainer(client clients.BlobContainersClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { return &storageBlobContainerWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, azureshared.StorageBlobContainer, ), @@ -46,11 +45,11 @@ func (s storageBlobContainerWrapper) Get(ctx context.Context, scope string, quer storageAccountName := queryParts[0] 
containerName := queryParts[1] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, storageAccountName, containerName) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, storageAccountName, containerName) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } @@ -73,11 +72,11 @@ func (s storageBlobContainerWrapper) Search(ctx context.Context, scope string, q if storageAccountName == "" { return nil, azureshared.QueryError(fmt.Errorf("storageAccountName cannot be empty"), scope, s.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.List(ctx, resourceGroup, storageAccountName) + pager := s.client.List(ctx, rgScope.ResourceGroup, storageAccountName) var items []*sdp.Item for pager.More() { diff --git a/sources/azure/manual/storage-blob-container_test.go b/sources/azure/manual/storage-blob-container_test.go index 9e1e7316..9c83b4d0 100644 --- a/sources/azure/manual/storage-blob-container_test.go +++ b/sources/azure/manual/storage-blob-container_test.go @@ -81,7 +81,7 @@ func TestStorageBlobContainer(t *testing.T) { }, nil) testClient := &testBlobContainersClient{MockBlobContainersClient: mockClient} - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Get requires storageAccountName and containerName as query parts @@ -185,7 +185,7 @@ func TestStorageBlobContainer(t *testing.T) { }, nil) testClient := &testBlobContainersClient{MockBlobContainersClient: mockClient} - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := storageAccountName + shared.QuerySeparator + containerName @@ -231,7 +231,7 @@ func TestStorageBlobContainer(t *testing.T) { mockClient := mocks.NewMockBlobContainersClient(ctrl) testClient := &testBlobContainersClient{MockBlobContainersClient: mockClient} - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (only storage account name) @@ -282,7 +282,7 @@ func TestStorageBlobContainer(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := 
adapter.(discovery.SearchableAdapter) @@ -316,7 +316,7 @@ func TestStorageBlobContainer(t *testing.T) { mockClient := mocks.NewMockBlobContainersClient(ctrl) testClient := &testBlobContainersClient{MockBlobContainersClient: mockClient} - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with no query parts - should return error before calling List _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) @@ -352,7 +352,7 @@ func TestStorageBlobContainer(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -383,7 +383,7 @@ func TestStorageBlobContainer(t *testing.T) { armstorage.BlobContainersClientGetResponse{}, expectedErr) testClient := &testBlobContainersClient{MockBlobContainersClient: mockClient} - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := storageAccountName + shared.QuerySeparator + "nonexistent-container" @@ -403,7 +403,7 @@ func TestStorageBlobContainer(t *testing.T) { pager: errorPager, } - wrapper := manual.NewStorageBlobContainer(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageBlobContainer(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) diff --git a/sources/azure/manual/storage-fileshare.go b/sources/azure/manual/storage-fileshare.go index 088b0888..ed3a3366 100644 --- a/sources/azure/manual/storage-fileshare.go +++ b/sources/azure/manual/storage-fileshare.go @@ -16,15 +16,14 @@ var StorageFileShareLookupByName = shared.NewItemTypeLookup("name", azureshared. 
type storageFileShareWrapper struct { client clients.FileSharesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewStorageFileShare(client clients.FileSharesClient, subscriptionID, resourceGroup string) sources.SearchableWrapper { +func NewStorageFileShare(client clients.FileSharesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { return &storageFileShareWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, azureshared.StorageFileShare, ), @@ -43,11 +42,11 @@ func (s storageFileShareWrapper) Get(ctx context.Context, scope string, queryPar storageAccountName := queryParts[0] shareName := queryParts[1] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, storageAccountName, shareName) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, storageAccountName, shareName) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } @@ -80,11 +79,11 @@ func (s storageFileShareWrapper) Search(ctx context.Context, scope string, query } storageAccountName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.List(ctx, resourceGroup, storageAccountName) + pager := s.client.List(ctx, rgScope.ResourceGroup, storageAccountName) var items []*sdp.Item for pager.More() { diff --git a/sources/azure/manual/storage-fileshare_test.go b/sources/azure/manual/storage-fileshare_test.go index a2a42557..ab0f5c91 100644 --- a/sources/azure/manual/storage-fileshare_test.go +++ b/sources/azure/manual/storage-fileshare_test.go @@ -80,7 +80,7 @@ func TestStorageFileShare(t *testing.T) { }, nil) testClient := &testFileSharesClient{MockFileSharesClient: mockClient} - wrapper := manual.NewStorageFileShare(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageFileShare(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Get requires storageAccountName and shareName as query parts @@ -141,7 +141,7 @@ func TestStorageFileShare(t *testing.T) { mockClient := mocks.NewMockFileSharesClient(ctrl) testClient := &testFileSharesClient{MockFileSharesClient: mockClient} - wrapper := manual.NewStorageFileShare(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageFileShare(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (only storage account name) @@ -192,7 +192,7 @@ func TestStorageFileShare(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageFileShare(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageFileShare(testClient, 
[]azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -226,7 +226,7 @@ func TestStorageFileShare(t *testing.T) { mockClient := mocks.NewMockFileSharesClient(ctrl) testClient := &testFileSharesClient{MockFileSharesClient: mockClient} - wrapper := manual.NewStorageFileShare(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageFileShare(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with no query parts - should return error before calling List _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) @@ -262,7 +262,7 @@ func TestStorageFileShare(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageFileShare(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageFileShare(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -293,7 +293,7 @@ func TestStorageFileShare(t *testing.T) { armstorage.FileSharesClientGetResponse{}, expectedErr) testClient := &testFileSharesClient{MockFileSharesClient: mockClient} - wrapper := manual.NewStorageFileShare(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageFileShare(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := storageAccountName + shared.QuerySeparator + "nonexistent-share" @@ -313,7 +313,7 @@ func TestStorageFileShare(t *testing.T) { pager: errorPager, } - wrapper := manual.NewStorageFileShare(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageFileShare(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) diff --git a/sources/azure/manual/storage-queues.go b/sources/azure/manual/storage-queues.go index e339c2fa..f3121421 100644 --- a/sources/azure/manual/storage-queues.go +++ b/sources/azure/manual/storage-queues.go @@ -16,15 +16,14 @@ var StorageQueueLookupByName = shared.NewItemTypeLookup("name", azureshared.Stor type storageQueuesWrapper struct { client clients.QueuesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewStorageQueues(client clients.QueuesClient, subscriptionID, resourceGroup string) sources.SearchableWrapper { +func NewStorageQueues(client clients.QueuesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { return &storageQueuesWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, azureshared.StorageQueue, ), @@ -43,11 +42,11 @@ func (s storageQueuesWrapper) Get(ctx context.Context, scope string, queryParts storageAccountName := queryParts[0] queueName := queryParts[1] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := 
s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, storageAccountName, queueName) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, storageAccountName, queueName) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } @@ -119,11 +118,11 @@ func (s storageQueuesWrapper) Search(ctx context.Context, scope string, queryPar } storageAccountName := queryParts[0] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.List(ctx, resourceGroup, storageAccountName) + pager := s.client.List(ctx, rgScope.ResourceGroup, storageAccountName) var items []*sdp.Item for pager.More() { diff --git a/sources/azure/manual/storage-queues_test.go b/sources/azure/manual/storage-queues_test.go index 49973f99..2e003bcc 100644 --- a/sources/azure/manual/storage-queues_test.go +++ b/sources/azure/manual/storage-queues_test.go @@ -80,7 +80,7 @@ func TestStorageQueues(t *testing.T) { }, nil) testClient := &testQueuesClient{MockQueuesClient: mockClient} - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Get requires storageAccountName and queueName as query parts @@ -142,7 +142,7 @@ func TestStorageQueues(t *testing.T) { mockClient := mocks.NewMockQueuesClient(ctrl) testClient := &testQueuesClient{MockQueuesClient: mockClient} - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (only storage account name) @@ -189,7 +189,7 @@ func TestStorageQueues(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -223,7 +223,7 @@ func TestStorageQueues(t *testing.T) { mockClient := mocks.NewMockQueuesClient(ctrl) testClient := &testQueuesClient{MockQueuesClient: mockClient} - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with no query parts - should return error before calling List _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) @@ -263,7 +263,7 @@ func TestStorageQueues(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok 
:= adapter.(discovery.SearchableAdapter) @@ -295,7 +295,7 @@ func TestStorageQueues(t *testing.T) { armstorage.QueueClientGetResponse{}, expectedErr) testClient := &testQueuesClient{MockQueuesClient: mockClient} - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := storageAccountName + shared.QuerySeparator + "nonexistent-queue" @@ -315,7 +315,7 @@ func TestStorageQueues(t *testing.T) { pager: errorPager, } - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -334,7 +334,7 @@ func TestStorageQueues(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockQueuesClient(ctrl) testClient := &testQueuesClient{MockQueuesClient: mockClient} - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements SearchableWrapper (it's returned as this type) if wrapper == nil { @@ -352,7 +352,7 @@ func TestStorageQueues(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockQueuesClient(ctrl) testClient := &testQueuesClient{MockQueuesClient: mockClient} - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) links := wrapper.PotentialLinks() if len(links) == 0 { @@ -367,7 +367,7 @@ func TestStorageQueues(t *testing.T) { t.Run("TerraformMappings", func(t *testing.T) { mockClient := mocks.NewMockQueuesClient(ctrl) testClient := &testQueuesClient{MockQueuesClient: mockClient} - wrapper := manual.NewStorageQueues(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageQueues(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) mappings := wrapper.TerraformMappings() if len(mappings) == 0 { diff --git a/sources/azure/manual/storage-table.go b/sources/azure/manual/storage-table.go index ca1039ff..f8db9364 100644 --- a/sources/azure/manual/storage-table.go +++ b/sources/azure/manual/storage-table.go @@ -17,15 +17,14 @@ var StorageTableLookupByName = shared.NewItemTypeLookup("name", azureshared.Stor type storageTablesWrapper struct { client clients.TablesClient - *azureshared.ResourceGroupBase + *azureshared.MultiResourceGroupBase } -func NewStorageTable(client clients.TablesClient, subscriptionID, resourceGroup string) sources.SearchableWrapper { +func NewStorageTable(client clients.TablesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { return &storageTablesWrapper{ client: client, - ResourceGroupBase: azureshared.NewResourceGroupBase( - subscriptionID, - resourceGroup, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, 
azureshared.StorageTable, ), @@ -44,11 +43,11 @@ func (s storageTablesWrapper) Get(ctx context.Context, scope string, queryParts storageAccountName := queryParts[0] tableName := queryParts[1] - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - resp, err := s.client.Get(ctx, resourceGroup, storageAccountName, tableName) + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, storageAccountName, tableName) if err != nil { return nil, azureshared.QueryError(err, scope, s.Type()) } @@ -123,11 +122,11 @@ func (s storageTablesWrapper) Search(ctx context.Context, scope string, queryPar return nil, azureshared.QueryError(fmt.Errorf("storageAccountName cannot be empty"), scope, s.Type()) } - resourceGroup := azureshared.ResourceGroupFromScope(scope) - if resourceGroup == "" { - resourceGroup = s.ResourceGroup() + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) } - pager := s.client.List(ctx, resourceGroup, storageAccountName) + pager := s.client.List(ctx, rgScope.ResourceGroup, storageAccountName) var items []*sdp.Item for pager.More() { diff --git a/sources/azure/manual/storage-table_test.go b/sources/azure/manual/storage-table_test.go index b0bc04c3..0ecc47cd 100644 --- a/sources/azure/manual/storage-table_test.go +++ b/sources/azure/manual/storage-table_test.go @@ -80,7 +80,7 @@ func TestStorageTables(t *testing.T) { }, nil) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Get requires storageAccountName and tableName as query parts @@ -142,7 +142,7 @@ func TestStorageTables(t *testing.T) { mockClient := mocks.NewMockTablesClient(ctrl) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) // Test with insufficient query parts (only storage account name) @@ -185,7 +185,7 @@ func TestStorageTables(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -219,7 +219,7 @@ func TestStorageTables(t *testing.T) { mockClient := mocks.NewMockTablesClient(ctrl) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Test Search directly with no query parts - should return error before calling List _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) @@ 
-257,7 +257,7 @@ func TestStorageTables(t *testing.T) { pager: mockPager, } - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -289,7 +289,7 @@ func TestStorageTables(t *testing.T) { armstorage.TableClientGetResponse{}, expectedErr) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) query := storageAccountName + shared.QuerySeparator + "nonexistent-table" @@ -309,7 +309,7 @@ func TestStorageTables(t *testing.T) { pager: errorPager, } - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) searchable, ok := adapter.(discovery.SearchableAdapter) @@ -328,7 +328,7 @@ func TestStorageTables(t *testing.T) { t.Run("InterfaceCompliance", func(t *testing.T) { mockClient := mocks.NewMockTablesClient(ctrl) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) // Verify wrapper implements SearchableWrapper (it's returned as this type) if wrapper == nil { @@ -346,7 +346,7 @@ func TestStorageTables(t *testing.T) { t.Run("PotentialLinks", func(t *testing.T) { mockClient := mocks.NewMockTablesClient(ctrl) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) links := wrapper.PotentialLinks() if len(links) == 0 { @@ -361,7 +361,7 @@ func TestStorageTables(t *testing.T) { t.Run("TerraformMappings", func(t *testing.T) { mockClient := mocks.NewMockTablesClient(ctrl) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) mappings := wrapper.TerraformMappings() if len(mappings) == 0 { @@ -392,7 +392,7 @@ func TestStorageTables(t *testing.T) { t.Run("IAMPermissions", func(t *testing.T) { mockClient := mocks.NewMockTablesClient(ctrl) testClient := &testTablesClient{MockTablesClient: mockClient} - wrapper := manual.NewStorageTable(testClient, subscriptionID, resourceGroup) + wrapper := manual.NewStorageTable(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) permissions := wrapper.IAMPermissions() if len(permissions) == 0 { diff --git a/sources/azure/proc/proc.go b/sources/azure/proc/proc.go index e8709e18..b0c78d2e 100644 --- 
a/sources/azure/proc/proc.go +++ b/sources/azure/proc/proc.go @@ -51,6 +51,7 @@ func init() { sdpcache.NewNoOpCache(), // no-op cache for metadata registration ) if err != nil { + // docs generation should fail if there are errors creating adapters panic(fmt.Errorf("error creating adapters: %w", err)) } @@ -61,12 +62,11 @@ func init() { log.Debug("Registered Azure source metadata", " with ", len(Metadata.AllAdapterMetadata()), " adapters") } -func Initialize(ctx context.Context, ec *discovery.EngineConfig, cfg *AzureConfig) (*discovery.Engine, error) { - engine, err := discovery.NewEngine(ec) - if err != nil { - return nil, fmt.Errorf("error initializing Engine: %w", err) - } - +// InitializeAdapters adds Azure adapters to an existing engine. This is a single-attempt +// function; retry logic is handled by the caller via Engine.InitialiseAdapters. +// +// cfg must not be nil — call ConfigFromViper() first for config validation. +func InitializeAdapters(ctx context.Context, engine *discovery.Engine, cfg *AzureConfig) error { // ReadinessCheck verifies adapters are healthy by using a StorageAccount adapter // Timeout is handled by SendHeartbeat, HTTP handlers rely on request context engine.SetReadinessCheck(func(ctx context.Context) error { @@ -95,106 +95,78 @@ func Initialize(ctx context.Context, ec *discovery.EngineConfig, cfg *AzureConfi // Create a shared cache for all adapters in this source sharedCache := sdpcache.NewCache(ctx) - err = func() error { - var logmsg string - // Use provided config, otherwise fall back to viper - if cfg != nil { - logmsg = "Using directly provided config" - } else { - var err error - cfg, err = readConfig() - if err != nil { - return fmt.Errorf("error creating config from command line: %w", err) - } - logmsg = "Using config from viper" - - } - log.WithFields(log.Fields{ - "ovm.source.type": "azure", - "ovm.source.subscription_id": cfg.SubscriptionID, - "ovm.source.tenant_id": cfg.TenantID, - "ovm.source.client_id": cfg.ClientID, - "ovm.source.regions": cfg.Regions, - }).Info(logmsg) - - // Regions are optional for Azure, but subscription ID is required - if cfg.SubscriptionID == "" { - return fmt.Errorf("Azure source must specify subscription ID") - } - - // Set Azure SDK environment variables from viper config if not already set. - // The Azure SDK's DefaultAzureCredential reads AZURE_CLIENT_ID and AZURE_TENANT_ID - // directly from environment variables for federated authentication. - // - // When using Azure Workload Identity webhook, these env vars are already injected - // by the webhook, so we only set them if they're not present. This supports both: - // 1. Azure Workload Identity webhook (env vars already injected) - // 2. 
Manual configuration (env vars set from viper config) - // - // Reference: https://azure.github.io/azure-workload-identity/docs/ - if os.Getenv("AZURE_CLIENT_ID") == "" && cfg.ClientID != "" { - os.Setenv("AZURE_CLIENT_ID", cfg.ClientID) - } - if os.Getenv("AZURE_TENANT_ID") == "" && cfg.TenantID != "" { - os.Setenv("AZURE_TENANT_ID", cfg.TenantID) - } - - // Initialize Azure credentials - cred, err := azureshared.NewAzureCredential(ctx) - if err != nil { - return fmt.Errorf("error creating Azure credentials: %w", err) - } - - // TODO: Implement linker when Azure dynamic adapters are available - var linker interface{} = nil + log.WithFields(log.Fields{ + "ovm.source.type": "azure", + "ovm.source.subscription_id": cfg.SubscriptionID, + "ovm.source.tenant_id": cfg.TenantID, + "ovm.source.client_id": cfg.ClientID, + "ovm.source.regions": cfg.Regions, + }).Info("Got config") + + // Regions are optional for Azure, but subscription ID is required + if cfg.SubscriptionID == "" { + return fmt.Errorf("Azure source must specify subscription ID") + } - discoveryAdapters, err := adapters(ctx, cfg.SubscriptionID, cfg.TenantID, cfg.ClientID, cfg.Regions, cred, linker, true, sharedCache) - if err != nil { - return fmt.Errorf("error creating discovery adapters: %w", err) - } + // Set Azure SDK environment variables from viper config if not already set. + // The Azure SDK's DefaultAzureCredential reads AZURE_CLIENT_ID and AZURE_TENANT_ID + // directly from environment variables for federated authentication. + // + // When using Azure Workload Identity webhook, these env vars are already injected + // by the webhook, so we only set them if they're not present. This supports both: + // 1. Azure Workload Identity webhook (env vars already injected) + // 2. Manual configuration (env vars set from viper config) + // + // Reference: https://azure.github.io/azure-workload-identity/docs/ + if os.Getenv("AZURE_CLIENT_ID") == "" && cfg.ClientID != "" { + os.Setenv("AZURE_CLIENT_ID", cfg.ClientID) + } + if os.Getenv("AZURE_TENANT_ID") == "" && cfg.TenantID != "" { + os.Setenv("AZURE_TENANT_ID", cfg.TenantID) + } - // Verify subscription access before adding adapters - err = checkSubscriptionAccess(ctx, cfg.SubscriptionID, cred) - if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "ovm.source.type": "azure", - "ovm.source.subscription_id": cfg.SubscriptionID, - }).Error("Permission check failed for subscription") - } else { - log.WithContext(ctx).WithFields(log.Fields{ - "ovm.source.type": "azure", - "ovm.source.subscription_id": cfg.SubscriptionID, - }).Info("Permission check passed for subscription") - } + // Initialize Azure credentials + cred, err := azureshared.NewAzureCredential(ctx) + if err != nil { + return fmt.Errorf("error creating Azure credentials: %w", err) + } - // Add the adapters to the engine - err = engine.AddAdapters(discoveryAdapters...) 
- if err != nil { - return fmt.Errorf("error adding adapters to engine: %w", err) - } + // TODO: Implement linker when Azure dynamic adapters are available + var linker interface{} = nil - return nil - }() + discoveryAdapters, err := adapters(ctx, cfg.SubscriptionID, cfg.TenantID, cfg.ClientID, cfg.Regions, cred, linker, true, sharedCache) + if err != nil { + return fmt.Errorf("error creating discovery adapters: %w", err) + } + // Verify subscription access before adding adapters + err = checkSubscriptionAccess(ctx, cfg.SubscriptionID, cred) if err != nil { - log.WithError(err).Debug("Error initializing Azure source") - return nil, fmt.Errorf("error initializing Azure source: %w", err) + log.WithContext(ctx).WithError(err).WithFields(log.Fields{ + "ovm.source.type": "azure", + "ovm.source.subscription_id": cfg.SubscriptionID, + }).Error("Permission check failed for subscription") + } else { + log.WithContext(ctx).WithFields(log.Fields{ + "ovm.source.type": "azure", + "ovm.source.subscription_id": cfg.SubscriptionID, + }).Info("Permission check passed for subscription") } - // Start sending heartbeats after adapters are successfully added - // This ensures the first heartbeat has adapters available for readiness checks - engine.StartSendingHeartbeats(ctx) - brokenHeart := engine.SendHeartbeat(ctx, nil) // Send the error immediately through the custom health check func - if brokenHeart != nil { - log.WithError(brokenHeart).Error("Error sending heartbeat") + // Add the adapters to the engine + err = engine.AddAdapters(discoveryAdapters...) + if err != nil { + return fmt.Errorf("error adding adapters to engine: %w", err) } log.Debug("Sources initialized") - // If there is no error then return the engine - return engine, nil + return nil } -func readConfig() (*AzureConfig, error) { +// ConfigFromViper reads and validates the Azure configuration from viper flags. +// This performs local validation only (no API calls) and should be called +// before InitializeAdapters to catch permanent config errors early. +func ConfigFromViper() (*AzureConfig, error) { subscriptionID := viper.GetString("azure-subscription-id") if subscriptionID == "" { return nil, fmt.Errorf("azure-subscription-id not set") diff --git a/sources/azure/shared/mocks/mock_storage_accounts_client.go b/sources/azure/shared/mocks/mock_storage_accounts_client.go index 125cc808..5cca190f 100644 --- a/sources/azure/shared/mocks/mock_storage_accounts_client.go +++ b/sources/azure/shared/mocks/mock_storage_accounts_client.go @@ -57,16 +57,16 @@ func (mr *MockStorageAccountsClientMockRecorder) Get(ctx, resourceGroupName, acc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockStorageAccountsClient)(nil).Get), ctx, resourceGroupName, accountName) } -// List mocks base method. -func (m *MockStorageAccountsClient) List(resourceGroupName string) clients.StorageAccountsPager { +// NewListByResourceGroupPager mocks base method. +func (m *MockStorageAccountsClient) NewListByResourceGroupPager(resourceGroupName string, options *armstorage.AccountsClientListByResourceGroupOptions) clients.StorageAccountsPager { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "List", resourceGroupName) + ret := m.ctrl.Call(m, "NewListByResourceGroupPager", resourceGroupName, options) ret0, _ := ret[0].(clients.StorageAccountsPager) return ret0 } -// List indicates an expected call of List. 
-func (mr *MockStorageAccountsClientMockRecorder) List(resourceGroupName any) *gomock.Call { +// NewListByResourceGroupPager indicates an expected call of NewListByResourceGroupPager. +func (mr *MockStorageAccountsClientMockRecorder) NewListByResourceGroupPager(resourceGroupName, options any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockStorageAccountsClient)(nil).List), resourceGroupName) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListByResourceGroupPager", reflect.TypeOf((*MockStorageAccountsClient)(nil).NewListByResourceGroupPager), resourceGroupName, options) } diff --git a/sources/azure/shared/scope.go b/sources/azure/shared/scope.go new file mode 100644 index 00000000..3c0a9281 --- /dev/null +++ b/sources/azure/shared/scope.go @@ -0,0 +1,84 @@ +package shared + +import ( + "fmt" + + "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sources/shared" +) + +// ResourceGroupScope represents a subscription and resource group pair. +// It is used by multi-scope adapters to handle multiple resource groups. +type ResourceGroupScope struct { + SubscriptionID string + ResourceGroup string +} + +// NewResourceGroupScope creates a ResourceGroupScope for the given subscription and resource group. +func NewResourceGroupScope(subscriptionID, resourceGroup string) ResourceGroupScope { + return ResourceGroupScope{ + SubscriptionID: subscriptionID, + ResourceGroup: resourceGroup, + } +} + +// ToScope returns the scope string in format "{subscriptionId}.{resourceGroup}". +func (r ResourceGroupScope) ToScope() string { + return fmt.Sprintf("%s.%s", r.SubscriptionID, r.ResourceGroup) +} + +// MultiResourceGroupBase provides shared multi-scope behavior for resource-group-scoped adapters. +// One adapter instance handles all resource groups in resourceGroupScopes. +type MultiResourceGroupBase struct { + resourceGroupScopes []ResourceGroupScope + *shared.Base +} + +// NewMultiResourceGroupBase creates a MultiResourceGroupBase that supports multiple resource group scopes. +func NewMultiResourceGroupBase( + resourceGroupScopes []ResourceGroupScope, + category sdp.AdapterCategory, + item shared.ItemType, +) *MultiResourceGroupBase { + if len(resourceGroupScopes) == 0 { + panic("NewMultiResourceGroupBase: resourceGroupScopes cannot be empty") + } + + scopeStrings := make([]string, 0, len(resourceGroupScopes)) + for _, rgScope := range resourceGroupScopes { + scopeStrings = append(scopeStrings, rgScope.ToScope()) + } + + return &MultiResourceGroupBase{ + resourceGroupScopes: resourceGroupScopes, + Base: shared.NewBase(category, item, scopeStrings), + } +} + +// ResourceGroupScopeFromScope parses a scope string and returns the matching ResourceGroupScope +// if it is one of the adapter's configured scopes. 
+func (m *MultiResourceGroupBase) ResourceGroupScopeFromScope(scope string) (ResourceGroupScope, error) { + subscriptionID := SubscriptionIDFromScope(scope) + resourceGroup := ResourceGroupFromScope(scope) + if subscriptionID == "" || resourceGroup == "" { + return ResourceGroupScope{}, fmt.Errorf("invalid scope format %q: expected subscriptionId.resourceGroup", scope) + } + + rgScope := NewResourceGroupScope(subscriptionID, resourceGroup) + for _, s := range m.resourceGroupScopes { + if s.SubscriptionID == rgScope.SubscriptionID && s.ResourceGroup == rgScope.ResourceGroup { + return rgScope, nil + } + } + return ResourceGroupScope{}, fmt.Errorf("scope %s not found in adapter resource group scopes", scope) +} + +// ResourceGroupScopes returns the configured resource group scopes for this adapter. +func (m *MultiResourceGroupBase) ResourceGroupScopes() []ResourceGroupScope { + return m.resourceGroupScopes +} + +// DefaultScope returns the first scope (for compatibility where a single default is needed). +func (m *MultiResourceGroupBase) DefaultScope() string { + return m.Scopes()[0] +} diff --git a/sources/gcp/cmd/root.go b/sources/gcp/cmd/root.go index 10ddbfa2..30cb4aae 100644 --- a/sources/gcp/cmd/root.go +++ b/sources/gcp/cmd/root.go @@ -24,63 +24,67 @@ var cfgFile string // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ - Use: "gcp-source", - Short: "Remote primary source for GCP", + Use: "gcp-source", + Short: "Remote primary source for GCP", + SilenceUsage: true, Long: `This sources looks for GCP resources in your account. `, - Run: func(cmd *cobra.Command, args []string) { - ctx := context.Background() + RunE: func(cmd *cobra.Command, args []string) error { + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() defer tracing.LogRecoverToReturn(ctx, "gcp-source.root") healthCheckPort := viper.GetInt("health-check-port") engineConfig, err := discovery.EngineConfigFromViper("gcp", tracing.Version()) if err != nil { - log.WithError(err).Fatal("Could not create engine config") + log.WithError(err).Error("Could not create engine config") + return fmt.Errorf("could not create engine config: %w", err) } - err = engineConfig.CreateClients() + // Create a basic engine first so we can serve health probes and heartbeats even if init fails + e, err := discovery.NewEngine(engineConfig) if err != nil { sentry.CaptureException(err) - log.WithError(err).Fatal("could not auth create clients") + log.WithError(err).Error("Could not create engine") + return fmt.Errorf("could not create engine: %w", err) } - e, err := proc.Initialize(ctx, engineConfig, nil) - if err != nil { - log.WithError(err).Fatal("Could not initialize GCP source") - } - - e.StartSendingHeartbeats(ctx) - + // Serve health probes before initialization so they're available even on failure e.ServeHealthProbes(healthCheckPort) + // Start the engine (NATS connection) before adapter init so heartbeats work err = e.Start(ctx) if err != nil { - log.WithFields(log.Fields{ - "ovm.source.type": "gcp", - "ovm.source.error": err, - }).Fatal("Could not start engine") + sentry.CaptureException(err) + log.WithError(err).Error("Could not start engine") + return fmt.Errorf("could not start engine: %w", err) } - sigs := make(chan os.Signal, 1) - - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + // Config validation (permanent errors — no retry, just idle with error) + gcpCfg, cfgErr := proc.ConfigFromViper() + if cfgErr != nil { + 
log.WithError(cfgErr).Error("GCP source config error - pod will stay running with error status") + e.SetInitError(cfgErr) + sentry.CaptureException(cfgErr) + } else { + // Adapter init (retryable errors — backoff capped at 5 min) + e.InitialiseAdapters(ctx, func(ctx context.Context) error { + return proc.InitializeAdapters(ctx, e, gcpCfg) + }) + } - <-sigs + <-ctx.Done() log.Info("Stopping engine") err = e.Stop() if err != nil { - log.WithFields(log.Fields{ - "ovm.source.type": "gcp", - "ovm.source.error": err, - }).Error("Could not stop engine") - - os.Exit(1) + log.WithError(err).Error("Could not stop engine") + return fmt.Errorf("could not stop engine: %w", err) } log.Info("Stopped") - os.Exit(0) + return nil }, } @@ -127,7 +131,7 @@ func init() { cobra.CheckErr(viper.BindPFlags(rootCmd.PersistentFlags())) // Run this before we do anything to set up the loglevel - rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { if lvl, err := log.ParseLevel(logLevel); err == nil { log.SetLevel(lvl) } else { @@ -140,23 +144,29 @@ func init() { log.AddHook(TerminationLogHook{}) // Bind flags that haven't been set to the values from viper of we have them + var bindErr error cmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { // Bind the flag to viper only if it has a non-empty default if f.DefValue != "" || f.Changed { - err := viper.BindPFlag(f.Name, f) - if err != nil { - log.WithError(err).Fatal("could not bind flag to viper") + if err := viper.BindPFlag(f.Name, f); err != nil { + bindErr = err } } }) + if bindErr != nil { + log.WithError(bindErr).Error("could not bind flag to viper") + return fmt.Errorf("could not bind flag to viper: %w", bindErr) + } if viper.GetBool("json-log") { logging.ConfigureLogrusJSON(log.StandardLogger()) } if err := tracing.InitTracerWithUpstreams("gcp-source", viper.GetString("honeycomb-api-key"), viper.GetString("sentry-dsn")); err != nil { - log.Fatal(err) + log.WithError(err).Error("could not init tracer") + return fmt.Errorf("could not init tracer: %w", err) } + return nil } // shut down tracing at the end of the process rootCmd.PersistentPostRun = func(cmd *cobra.Command, args []string) { diff --git a/sources/gcp/dynamic/adapter-listable.go b/sources/gcp/dynamic/adapter-listable.go index ad485aa8..5a4d9fda 100644 --- a/sources/gcp/dynamic/adapter-listable.go +++ b/sources/gcp/dynamic/adapter-listable.go @@ -26,7 +26,7 @@ func NewListableAdapter(listEndpointFunc gcpshared.ListEndpointFunc, config *Ada Adapter: Adapter{ locations: config.Locations, httpCli: config.HTTPClient, - Cache: cache, + cache: cache, getURLFunc: config.GetURLFunc, sdpAssetType: config.SDPAssetType, sdpAdapterCategory: config.SDPAdapterCategory, @@ -63,7 +63,7 @@ func (g ListableAdapter) List(ctx context.Context, scope string, ignoreCache boo return nil, err } - cacheHit, ck, cachedItems, qErr, done := g.GetCache().Lookup( + cacheHit, ck, cachedItems, qErr, done := g.cache.Lookup( ctx, g.Name(), sdp.QueryMethod_LIST, @@ -94,18 +94,18 @@ func (g ListableAdapter) List(ctx context.Context, scope string, ignoreCache boo ErrorType: sdp.QueryError_OTHER, ErrorString: fmt.Sprintf("failed to construct list endpoint: %v", err), } - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } items, err := aggregateSDPItems(ctx, g.Adapter, listURL, location) if err != nil { - g.GetCache().StoreError(ctx, err, 
shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } for _, item := range items { - g.GetCache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + g.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) } return items, nil @@ -118,7 +118,7 @@ func (g ListableAdapter) ListStream(ctx context.Context, scope string, ignoreCac return } - cacheHit, ck, cachedItems, qErr, done := g.GetCache().Lookup( + cacheHit, ck, cachedItems, qErr, done := g.cache.Lookup( ctx, g.Name(), sdp.QueryMethod_LIST, @@ -156,5 +156,5 @@ func (g ListableAdapter) ListStream(ctx context.Context, scope string, ignoreCac return } - streamSDPItems(ctx, g.Adapter, listURL, location, stream, g.GetCache(), ck) + streamSDPItems(ctx, g.Adapter, listURL, location, stream, g.cache, ck) } diff --git a/sources/gcp/dynamic/adapter-searchable-listable.go b/sources/gcp/dynamic/adapter-searchable-listable.go index b8827bab..34d39f1f 100644 --- a/sources/gcp/dynamic/adapter-searchable-listable.go +++ b/sources/gcp/dynamic/adapter-searchable-listable.go @@ -36,7 +36,7 @@ func NewSearchableListableAdapter(searchURLFunc gcpshared.EndpointFunc, listEndp Adapter: Adapter{ locations: config.Locations, httpCli: config.HTTPClient, - Cache: cache, + cache: cache, getURLFunc: config.GetURLFunc, sdpAssetType: config.SDPAssetType, sdpAdapterCategory: config.SDPAdapterCategory, @@ -76,7 +76,7 @@ func (g SearchableListableAdapter) Search(ctx context.Context, scope, query stri return nil, err } - cacheHit, ck, cachedItems, qErr, done := g.GetCache().Lookup( + cacheHit, ck, cachedItems, qErr, done := g.cache.Lookup( ctx, g.Name(), sdp.QueryMethod_SEARCH, @@ -105,7 +105,7 @@ func (g SearchableListableAdapter) Search(ctx context.Context, scope, query stri // This must be a terraform query in the format of: // projects/{{project}}/datasets/{{dataset}}/tables/{{name}} // projects/{{project}}/serviceAccounts/{{account}}/keys/{{key}} - return terraformMappingViaSearch(ctx, g.Adapter, query, location, g.GetCache(), ck) + return terraformMappingViaSearch(ctx, g.Adapter, query, location, g.cache, ck) } searchEndpoint := g.searchEndpointFunc(query, location) @@ -114,18 +114,18 @@ func (g SearchableListableAdapter) Search(ctx context.Context, scope, query stri ErrorType: sdp.QueryError_OTHER, ErrorString: fmt.Sprintf("no search endpoint found for query \"%s\". 
%s", query, g.Metadata().GetSupportedQueryMethods().GetSearchDescription()), } - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } items, err := aggregateSDPItems(ctx, g.Adapter, searchEndpoint, location) if err != nil { - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } for _, item := range items { - g.GetCache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + g.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) } return items, nil @@ -138,7 +138,7 @@ func (g SearchableListableAdapter) SearchStream(ctx context.Context, scope, quer return } - cacheHit, ck, cachedItems, qErr, done := g.GetCache().Lookup( + cacheHit, ck, cachedItems, qErr, done := g.cache.Lookup( ctx, g.Name(), sdp.QueryMethod_SEARCH, @@ -171,7 +171,7 @@ func (g SearchableListableAdapter) SearchStream(ctx context.Context, scope, quer // This must be a terraform query in the format of: // projects/{{project}}/datasets/{{dataset}}/tables/{{name}} // projects/{{project}}/serviceAccounts/{{account}}/keys/{{key}} - items, err := terraformMappingViaSearch(ctx, g.Adapter, query, location, g.GetCache(), ck) + items, err := terraformMappingViaSearch(ctx, g.Adapter, query, location, g.cache, ck) if err != nil { stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_OTHER, @@ -179,7 +179,7 @@ func (g SearchableListableAdapter) SearchStream(ctx context.Context, scope, quer }) return } - g.GetCache().StoreItem(ctx, items[0], shared.DefaultCacheDuration, ck) + g.cache.StoreItem(ctx, items[0], shared.DefaultCacheDuration, ck) // There should only be one item in the result, so we can send it directly stream.SendItem(items[0]) @@ -199,5 +199,5 @@ func (g SearchableListableAdapter) SearchStream(ctx context.Context, scope, quer return } - streamSDPItems(ctx, g.Adapter, searchURL, location, stream, g.GetCache(), ck) + streamSDPItems(ctx, g.Adapter, searchURL, location, stream, g.cache, ck) } diff --git a/sources/gcp/dynamic/adapter-searchable.go b/sources/gcp/dynamic/adapter-searchable.go index 6e78f549..b3f009eb 100644 --- a/sources/gcp/dynamic/adapter-searchable.go +++ b/sources/gcp/dynamic/adapter-searchable.go @@ -30,7 +30,7 @@ func NewSearchableAdapter(searchEndpointFunc gcpshared.EndpointFunc, config *Ada Adapter: Adapter{ locations: config.Locations, httpCli: config.HTTPClient, - Cache: cache, + cache: cache, getURLFunc: config.GetURLFunc, sdpAssetType: config.SDPAssetType, sdpAdapterCategory: config.SDPAdapterCategory, @@ -67,7 +67,7 @@ func (g SearchableAdapter) Search(ctx context.Context, scope, query string, igno return nil, err } - cacheHit, ck, cachedItems, qErr, done := g.GetCache().Lookup( + cacheHit, ck, cachedItems, qErr, done := g.cache.Lookup( ctx, g.Name(), sdp.QueryMethod_SEARCH, @@ -95,7 +95,7 @@ func (g SearchableAdapter) Search(ctx context.Context, scope, query string, igno // This must be a terraform query in the format of: // projects/{{project}}/datasets/{{dataset}}/tables/{{name}} // projects/{{project}}/serviceAccounts/{{account}}/keys/{{key}} - return terraformMappingViaSearch(ctx, g.Adapter, query, location, g.GetCache(), ck) + return terraformMappingViaSearch(ctx, g.Adapter, query, location, g.cache, ck) } // This is a regular SEARCH call @@ -105,18 +105,18 @@ func (g SearchableAdapter) Search(ctx context.Context, scope, query string, igno ErrorType: sdp.QueryError_OTHER, ErrorString: 
fmt.Sprintf("no search endpoint found for query \"%s\". %s", query, g.Metadata().GetSupportedQueryMethods().GetSearchDescription()), } - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } items, err := aggregateSDPItems(ctx, g.Adapter, searchEndpoint, location) if err != nil { - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } for _, item := range items { - g.GetCache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + g.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) } return items, nil @@ -129,7 +129,7 @@ func (g SearchableAdapter) SearchStream(ctx context.Context, scope, query string return } - cacheHit, ck, cachedItems, qErr, done := g.GetCache().Lookup( + cacheHit, ck, cachedItems, qErr, done := g.cache.Lookup( ctx, g.Name(), sdp.QueryMethod_SEARCH, @@ -161,7 +161,7 @@ func (g SearchableAdapter) SearchStream(ctx context.Context, scope, query string // This must be a terraform query in the format of: // projects/{{project}}/datasets/{{dataset}}/tables/{{name}} // projects/{{project}}/serviceAccounts/{{account}}/keys/{{key}} - items, err := terraformMappingViaSearch(ctx, g.Adapter, query, location, g.GetCache(), ck) + items, err := terraformMappingViaSearch(ctx, g.Adapter, query, location, g.cache, ck) if err != nil { stream.SendError(&sdp.QueryError{ ErrorType: sdp.QueryError_OTHER, @@ -169,7 +169,7 @@ func (g SearchableAdapter) SearchStream(ctx context.Context, scope, query string }) return } - g.GetCache().StoreItem(ctx, items[0], shared.DefaultCacheDuration, ck) + g.cache.StoreItem(ctx, items[0], shared.DefaultCacheDuration, ck) // There should only be one item in the result, so we can send it directly stream.SendItem(items[0]) @@ -189,5 +189,5 @@ func (g SearchableAdapter) SearchStream(ctx context.Context, scope, query string return } - streamSDPItems(ctx, g.Adapter, searchURL, location, stream, g.GetCache(), ck) + streamSDPItems(ctx, g.Adapter, searchURL, location, stream, g.cache, ck) } diff --git a/sources/gcp/dynamic/adapter.go b/sources/gcp/dynamic/adapter.go index eb6a4ac7..f04c9674 100644 --- a/sources/gcp/dynamic/adapter.go +++ b/sources/gcp/dynamic/adapter.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/http" - "sync" "buf.build/go/protovalidate" log "github.com/sirupsen/logrus" @@ -35,7 +34,7 @@ type AdapterConfig struct { type Adapter struct { locations []gcpshared.LocationInfo httpCli *http.Client - Cache sdpcache.Cache + cache sdpcache.Cache getURLFunc gcpshared.EndpointFunc sdpAssetType shared.ItemType sdpAdapterCategory sdp.AdapterCategory @@ -53,7 +52,7 @@ func NewAdapter(config *AdapterConfig, cache sdpcache.Cache) discovery.Adapter { return Adapter{ locations: config.Locations, httpCli: config.HTTPClient, - Cache: cache, + cache: cache, getURLFunc: config.GetURLFunc, sdpAssetType: config.SDPAssetType, sdpAdapterCategory: config.SDPAdapterCategory, @@ -93,21 +92,6 @@ func (g Adapter) Scopes() []string { return gcpshared.LocationsToScopes(g.locations) } -var ( - noOpCacheGCPOnce sync.Once - noOpCacheGCP sdpcache.Cache -) - -func (g Adapter) GetCache() sdpcache.Cache { - if g.Cache == nil { - noOpCacheGCPOnce.Do(func() { - noOpCacheGCP = sdpcache.NewNoOpCache() - }) - return noOpCacheGCP - } - return g.Cache -} - // validateScope checks if the requested scope matches one of the adapter's locations. 
func (g Adapter) validateScope(scope string) (gcpshared.LocationInfo, error) { requestedLoc, err := gcpshared.LocationFromScope(scope) @@ -135,7 +119,7 @@ func (g Adapter) Get(ctx context.Context, scope string, query string, ignoreCach return nil, err } - cacheHit, ck, cachedItem, qErr, done := g.GetCache().Lookup( + cacheHit, ck, cachedItem, qErr, done := g.cache.Lookup( ctx, g.Name(), sdp.QueryMethod_GET, @@ -169,23 +153,23 @@ func (g Adapter) Get(ctx context.Context, scope string, query string, ignoreCach g.Metadata().GetSupportedQueryMethods().GetGetDescription(), ), } - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } resp, err := externalCallSingle(ctx, g.httpCli, url) if err != nil { - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } item, err := externalToSDP(ctx, location, g.uniqueAttributeKeys, resp, g.sdpAssetType, g.linker, g.nameSelector) if err != nil { - g.GetCache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + g.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } - g.GetCache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + g.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) return item, nil } diff --git a/sources/gcp/manual/big-query-routine.go b/sources/gcp/manual/big-query-routine.go index 18c880af..3548b49f 100644 --- a/sources/gcp/manual/big-query-routine.go +++ b/sources/gcp/manual/big-query-routine.go @@ -57,10 +57,12 @@ func (b BigQueryRoutineWrapper) PotentialLinks() map[shared.ItemType]bool { func (b BigQueryRoutineWrapper) TerraformMappings() []*sdp.TerraformMapping { return []*sdp.TerraformMapping{ { - TerraformMethod: sdp.QueryMethod_GET, + TerraformMethod: sdp.QueryMethod_SEARCH, // https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/bigquery_routine - // projects/{{project}}/datasets/{{dataset_id}}/routines/{{routine_id}} - TerraformQueryMap: "google_bigquery_routine.routine_id", + // ID format: projects/{{project}}/datasets/{{dataset_id}}/routines/{{routine_id}} + // The framework automatically intercepts queries starting with "projects/" and converts + // them to GET operations by extracting the last N path parameters (based on GetLookups count). 
+ TerraformQueryMap: "google_bigquery_routine.id", }, } } @@ -121,7 +123,32 @@ func (b BigQueryRoutineWrapper) Search(ctx context.Context, scope string, queryP } func (b BigQueryRoutineWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { - // SearchStream not implemented for BigQueryRoutine + location, err := b.LocationFromScope(scope) + if err != nil { + stream.SendError(&sdp.QueryError{ + ErrorType: sdp.QueryError_NOSCOPE, + ErrorString: err.Error(), + }) + return + } + + toItem := func(metadata *bigquery.RoutineMetadata, datasetID, routineID string) (*sdp.Item, *sdp.QueryError) { + item, qerr := b.gcpBigQueryRoutineToItem(metadata, datasetID, routineID, location) + if qerr == nil && item != nil { + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + } + return item, qerr + } + + items, listErr := b.client.List(ctx, location.ProjectID, queryParts[0], toItem) + if listErr != nil { + stream.SendError(gcpshared.QueryError(listErr, scope, b.Type())) + return + } + + for _, item := range items { + stream.SendItem(item) + } } func (b BigQueryRoutineWrapper) gcpBigQueryRoutineToItem(metadata *bigquery.RoutineMetadata, datasetID, routineID string, location gcpshared.LocationInfo) (*sdp.Item, *sdp.QueryError) { diff --git a/sources/gcp/manual/big-query-routine_test.go b/sources/gcp/manual/big-query-routine_test.go index 26a9e543..8aa325c8 100644 --- a/sources/gcp/manual/big-query-routine_test.go +++ b/sources/gcp/manual/big-query-routine_test.go @@ -183,6 +183,71 @@ func TestBigQueryRoutine(t *testing.T) { t.Fatalf("Expected error, got nil") } }) + + t.Run("Search with terraform format", func(t *testing.T) { + wrapper := manual.NewBigQueryRoutine(mockClient, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Use terraform-style path format + terraformStyleQuery := "projects/test-project/datasets/test_dataset/routines/test_routine" + + // Mock Get (called internally when terraform format is detected) + mockClient.EXPECT().Get(ctx, projectID, datasetID, routineID).Return(createRoutineMetadata("terraform format test"), nil) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], terraformStyleQuery, true) + if qErr != nil { + t.Fatalf("Expected no error with terraform format, got: %v", qErr) + } + if len(items) != 1 { + t.Fatalf("Expected 1 item, got: %d", len(items)) + } + if items[0].GetType() != gcpshared.BigQueryRoutine.String() { + t.Fatalf("Expected type %s, got: %s", gcpshared.BigQueryRoutine.String(), items[0].GetType()) + } + }) + + t.Run("Search with legacy pipe format", func(t *testing.T) { + wrapper := manual.NewBigQueryRoutine(mockClient, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Use legacy dataset ID format + legacyQuery := datasetID + + // Mock the List function + mockClient.EXPECT().List( + gomock.Any(), + projectID, + datasetID, + gomock.Any(), + ).DoAndReturn(func(ctx context.Context, projectID string, datasetID string, converter func(routine *bigquery.RoutineMetadata, datasetID, routineID string) (*sdp.Item, *sdp.QueryError)) ([]*sdp.Item, *sdp.QueryError) { + items := make([]*sdp.Item, 0, 1) + routine := 
createRoutineMetadata("legacy format test") + item, qErr := converter(routine, datasetID, routineID) + if qErr != nil { + return nil, qErr + } + items = append(items, item) + return items, nil + }) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], legacyQuery, true) + if qErr != nil { + t.Fatalf("Expected no error with legacy format, got: %v", qErr) + } + if len(items) != 1 { + t.Fatalf("Expected 1 item, got: %d", len(items)) + } + }) } func createRoutineMetadata(description string) *bigquery.RoutineMetadata { diff --git a/sources/gcp/manual/big-query-table.go b/sources/gcp/manual/big-query-table.go index c4d5739e..1f9f9475 100644 --- a/sources/gcp/manual/big-query-table.go +++ b/sources/gcp/manual/big-query-table.go @@ -61,9 +61,11 @@ func (b BigQueryTableWrapper) PotentialLinks() map[shared.ItemType]bool { func (b BigQueryTableWrapper) TerraformMappings() []*sdp.TerraformMapping { return []*sdp.TerraformMapping{ { - TerraformMethod: sdp.QueryMethod_GET, + TerraformMethod: sdp.QueryMethod_SEARCH, // https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/bigquery_table - // projects/{{project}}/datasets/{{dataset}}/tables/{{name}} + // ID format: projects/{{project}}/datasets/{{dataset}}/tables/{{name}} + // The framework automatically intercepts queries starting with "projects/" and converts + // them to GET operations by extracting the last N path parameters (based on GetLookups count). TerraformQueryMap: "google_bigquery_table.id", }, } diff --git a/sources/gcp/manual/cloud-kms-crypto-key-version.go b/sources/gcp/manual/cloud-kms-crypto-key-version.go index bfd7e84f..302d2fdc 100644 --- a/sources/gcp/manual/cloud-kms-crypto-key-version.go +++ b/sources/gcp/manual/cloud-kms-crypto-key-version.go @@ -55,7 +55,11 @@ func (c cloudKMSCryptoKeyVersionWrapper) PotentialLinks() map[shared.ItemType]bo func (c cloudKMSCryptoKeyVersionWrapper) TerraformMappings() []*sdp.TerraformMapping { return []*sdp.TerraformMapping{ { - TerraformMethod: sdp.QueryMethod_GET, + TerraformMethod: sdp.QueryMethod_SEARCH, + // https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/kms_crypto_key_version + // ID format: projects/{project}/locations/{location}/keyRings/{keyRing}/cryptoKeys/{cryptoKey}/cryptoKeyVersions/{version} + // The framework automatically intercepts queries starting with "projects/" and converts + // them to GET operations by extracting the last N path parameters (based on GetLookups count). 
TerraformQueryMap: "google_kms_crypto_key_version.id", }, } diff --git a/sources/gcp/manual/cloud-kms-crypto-key-version_test.go b/sources/gcp/manual/cloud-kms-crypto-key-version_test.go index bb8821e4..9a48356f 100644 --- a/sources/gcp/manual/cloud-kms-crypto-key-version_test.go +++ b/sources/gcp/manual/cloud-kms-crypto-key-version_test.go @@ -19,7 +19,7 @@ func TestCloudKMSCryptoKeyVersion(t *testing.T) { projectID := "test-project-id" t.Run("Get_CacheHit", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Pre-populate cache with a CryptoKeyVersion item @@ -70,7 +70,7 @@ func TestCloudKMSCryptoKeyVersion(t *testing.T) { }) t.Run("Get_CacheMiss_NotFound", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Pre-populate cache with a NOTFOUND error to simulate item not existing @@ -101,7 +101,7 @@ func TestCloudKMSCryptoKeyVersion(t *testing.T) { }) t.Run("Search_CacheHit", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Pre-populate cache with CryptoKeyVersion items under SEARCH cache key (by cryptoKey) @@ -158,7 +158,7 @@ func TestCloudKMSCryptoKeyVersion(t *testing.T) { }) t.Run("Search_CacheHit_Empty", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Store NOTFOUND error in cache to simulate empty result @@ -191,7 +191,7 @@ func TestCloudKMSCryptoKeyVersion(t *testing.T) { }) t.Run("List_Unsupported", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() loader := gcpshared.NewCloudKMSAssetLoader(nil, projectID, cache, "gcp-source", []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) @@ -206,10 +206,130 @@ func TestCloudKMSCryptoKeyVersion(t *testing.T) { } }) - t.Run("StaticTests", func(t *testing.T) { + t.Run("Search_TerraformFormat", func(t *testing.T) { cache := sdpcache.NewCache(ctx) defer cache.Clear() + // Pre-populate cache with CryptoKeyVersion items under SEARCH cache key (by cryptoKey) + attrs1, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/us-central1/keyRings/my-keyring/cryptoKeys/my-key/cryptoKeyVersions/1", + "uniqueAttr": "us-central1|my-keyring|my-key|1", + }) + _ = attrs1.Set("uniqueAttr", "us-central1|my-keyring|my-key|1") + + attrs2, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/us-central1/keyRings/my-keyring/cryptoKeys/my-key/cryptoKeyVersions/2", + "uniqueAttr": "us-central1|my-keyring|my-key|2", + }) + _ = attrs2.Set("uniqueAttr", "us-central1|my-keyring|my-key|2") + + item1 := &sdp.Item{ + Type: gcpshared.CloudKMSCryptoKeyVersion.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs1, + Scope: projectID, + Health: sdp.Health_HEALTH_OK.Enum(), + } + item2 := &sdp.Item{ + Type: gcpshared.CloudKMSCryptoKeyVersion.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs2, + Scope: projectID, + Health: sdp.Health_HEALTH_OK.Enum(), + } + + // Search by location|keyRing|cryptoKey (what the terraform format will be converted to) + searchCacheKey := sdpcache.CacheKeyFromParts("gcp-source", sdp.QueryMethod_SEARCH, projectID, gcpshared.CloudKMSCryptoKeyVersion.String(), "us-central1|my-keyring|my-key") + cache.StoreItem(ctx, item1, shared.DefaultCacheDuration, searchCacheKey) + cache.StoreItem(ctx, item2, 
shared.DefaultCacheDuration, searchCacheKey) + + loader := gcpshared.NewCloudKMSAssetLoader(nil, projectID, cache, "gcp-source", []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + + wrapper := manual.NewCloudKMSCryptoKeyVersion(loader, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, cache) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + // Use terraform-style path format + terraformStyleQuery := "projects/test-project-id/locations/us-central1/keyRings/my-keyring/cryptoKeys/my-key/cryptoKeyVersions/1" + + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], terraformStyleQuery, false) + if qErr != nil { + t.Fatalf("Expected no error with terraform format, got: %v", qErr) + } + + // Verify we got at least one item back + if len(items) == 0 { + t.Fatalf("Expected at least 1 item with terraform format, got: %d", len(items)) + } + + // Verify the items have the expected unique attributes + foundVersion1 := false + for _, item := range items { + uniqueAttr, err := item.GetAttributes().Get("uniqueAttr") + if err == nil && (uniqueAttr == "us-central1|my-keyring|my-key|1" || uniqueAttr == "us-central1|my-keyring|my-key|2") { + if uniqueAttr == "us-central1|my-keyring|my-key|1" { + foundVersion1 = true + } + } + } + + if !foundVersion1 { + t.Fatalf("Expected to find version 1 in results") + } + }) + + t.Run("Search_LegacyPipeFormat", func(t *testing.T) { + cache := sdpcache.NewCache(ctx) + defer cache.Clear() + + // Pre-populate cache with CryptoKeyVersion items + attrs1, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/europe-west1/keyRings/prod-keyring/cryptoKeys/prod-key/cryptoKeyVersions/1", + "uniqueAttr": "europe-west1|prod-keyring|prod-key|1", + }) + _ = attrs1.Set("uniqueAttr", "europe-west1|prod-keyring|prod-key|1") + + item1 := &sdp.Item{ + Type: gcpshared.CloudKMSCryptoKeyVersion.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs1, + Scope: projectID, + Health: sdp.Health_HEALTH_OK.Enum(), + } + + // Search by location|keyRing|cryptoKey (legacy format) + searchCacheKey := sdpcache.CacheKeyFromParts("gcp-source", sdp.QueryMethod_SEARCH, projectID, gcpshared.CloudKMSCryptoKeyVersion.String(), "europe-west1|prod-keyring|prod-key") + cache.StoreItem(ctx, item1, shared.DefaultCacheDuration, searchCacheKey) + + loader := gcpshared.NewCloudKMSAssetLoader(nil, projectID, cache, "gcp-source", []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + + wrapper := manual.NewCloudKMSCryptoKeyVersion(loader, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, cache) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + // Use legacy pipe-separated format with multiple query parts + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], "europe-west1|prod-keyring|prod-key", false) + if qErr != nil { + t.Fatalf("Expected no error with legacy format, got: %v", qErr) + } + + if len(items) != 1 { + t.Fatalf("Expected 1 item with legacy format, got: %d", len(items)) + } + }) + + t.Run("StaticTests", func(t *testing.T) { + cache := sdpcache.NewMemoryCache() + defer cache.Clear() + // Pre-populate cache with a CryptoKeyVersion item with linked queries attrs, _ := sdp.ToAttributesViaJson(map[string]interface{}{ "name": 
"projects/test-project-id/locations/us/keyRings/test-keyring/cryptoKeys/test-key/cryptoKeyVersions/1", diff --git a/sources/gcp/manual/cloud-kms-crypto-key.go b/sources/gcp/manual/cloud-kms-crypto-key.go index 1dbd1eeb..39a67a0f 100644 --- a/sources/gcp/manual/cloud-kms-crypto-key.go +++ b/sources/gcp/manual/cloud-kms-crypto-key.go @@ -55,9 +55,16 @@ func (c cloudKMSCryptoKeyWrapper) PotentialLinks() map[shared.ItemType]bool { // TerraformMappings returns the Terraform mappings for the CryptoKey wrapper. func (c cloudKMSCryptoKeyWrapper) TerraformMappings() []*sdp.TerraformMapping { - // TODO: Revisit this when working on this ticket: - // https://linear.app/overmind/issue/ENG-706/fix-terraform-mappings-for-crypto-key - return nil + return []*sdp.TerraformMapping{ + { + TerraformMethod: sdp.QueryMethod_SEARCH, + // https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/kms_crypto_key + // ID format: projects/{{project}}/locations/{{location}}/keyRings/{{keyRing}}/cryptoKeys/{{name}} + // The framework automatically intercepts queries starting with "projects/" and converts + // them to GET operations by extracting the last N path parameters (based on GetLookups count). + TerraformQueryMap: "google_kms_crypto_key.id", + }, + } } // GetLookups returns the lookups for the CryptoKey wrapper. diff --git a/sources/gcp/manual/cloud-kms-crypto-key_test.go b/sources/gcp/manual/cloud-kms-crypto-key_test.go index 20948a98..9d60c092 100644 --- a/sources/gcp/manual/cloud-kms-crypto-key_test.go +++ b/sources/gcp/manual/cloud-kms-crypto-key_test.go @@ -19,7 +19,7 @@ func TestCloudKMSCryptoKey(t *testing.T) { projectID := "test-project-id" t.Run("Get_CacheHit", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Pre-populate cache with a CryptoKey item @@ -187,6 +187,128 @@ func TestCloudKMSCryptoKey(t *testing.T) { } }) + t.Run("Search_TerraformFormat", func(t *testing.T) { + cache := sdpcache.NewCache(ctx) + defer cache.Clear() + + // Pre-populate cache with a specific CryptoKey item + // Note: Terraform queries with full path are converted to GET operations by the adapter framework + attrs, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/us-central1/keyRings/my-keyring/cryptoKeys/my-key-1", + "uniqueAttr": "us-central1|my-keyring|my-key-1", + }) + _ = attrs.Set("uniqueAttr", "us-central1|my-keyring|my-key-1") + + item := &sdp.Item{ + Type: gcpshared.CloudKMSCryptoKey.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs, + Scope: projectID, + } + + // Store with GET cache key (terraform queries are converted to GET operations) + getCacheKey := sdpcache.CacheKeyFromParts("gcp-source", sdp.QueryMethod_GET, projectID, gcpshared.CloudKMSCryptoKey.String(), "us-central1|my-keyring|my-key-1") + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, getCacheKey) + + loader := gcpshared.NewCloudKMSAssetLoader(nil, projectID, cache, "gcp-source", []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + + wrapper := manual.NewCloudKMSCryptoKey(loader, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, cache) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + // Search using terraform-style path format + // The adapter framework detects this and converts it to a GET operation + terraformID := 
"projects/test-project-id/locations/us-central1/keyRings/my-keyring/cryptoKeys/my-key-1" + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], terraformID, false) + if qErr != nil { + t.Fatalf("Expected no error with terraform format, got: %v", qErr) + } + + // Terraform queries with full path return 1 specific item (converted to GET) + if len(items) != 1 { + t.Fatalf("Expected 1 item with terraform format (converted to GET), got: %d", len(items)) + } + + // Verify the returned item has the correct unique attribute + uniqueAttr, err := items[0].GetAttributes().Get("uniqueAttr") + if err != nil { + t.Fatalf("Failed to get uniqueAttr: %v", err) + } + if uniqueAttr != "us-central1|my-keyring|my-key-1" { + t.Fatalf("Expected uniqueAttr 'us-central1|my-keyring|my-key-1', got: %v", uniqueAttr) + } + }) + + t.Run("Search_LegacyFormat", func(t *testing.T) { + cache := sdpcache.NewCache(ctx) + defer cache.Clear() + + // Pre-populate cache with CryptoKey items + attrs1, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/us-central1/keyRings/my-keyring/cryptoKeys/my-key-1", + "uniqueAttr": "us-central1|my-keyring|my-key-1", + }) + _ = attrs1.Set("uniqueAttr", "us-central1|my-keyring|my-key-1") + + attrs2, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/us-central1/keyRings/my-keyring/cryptoKeys/my-key-2", + "uniqueAttr": "us-central1|my-keyring|my-key-2", + }) + _ = attrs2.Set("uniqueAttr", "us-central1|my-keyring|my-key-2") + + item1 := &sdp.Item{ + Type: gcpshared.CloudKMSCryptoKey.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs1, + Scope: projectID, + } + item2 := &sdp.Item{ + Type: gcpshared.CloudKMSCryptoKey.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs2, + Scope: projectID, + } + + // Store with location|keyRing search key + searchCacheKey := sdpcache.CacheKeyFromParts("gcp-source", sdp.QueryMethod_SEARCH, projectID, gcpshared.CloudKMSCryptoKey.String(), "us-central1|my-keyring") + cache.StoreItem(ctx, item1, shared.DefaultCacheDuration, searchCacheKey) + cache.StoreItem(ctx, item2, shared.DefaultCacheDuration, searchCacheKey) + + loader := gcpshared.NewCloudKMSAssetLoader(nil, projectID, cache, "gcp-source", []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + + wrapper := manual.NewCloudKMSCryptoKey(loader, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, cache) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + // Search using legacy pipe format + legacyQuery := "us-central1|my-keyring" + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], legacyQuery, false) + if qErr != nil { + t.Fatalf("Expected no error with legacy format, got: %v", qErr) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items with legacy format, got: %d", len(items)) + } + + // Verify the returned items have the correct unique attributes + uniqueAttr1, err := items[0].GetAttributes().Get("uniqueAttr") + if err != nil { + t.Fatalf("Failed to get uniqueAttr from item 1: %v", err) + } + if uniqueAttr1 != "us-central1|my-keyring|my-key-1" { + t.Fatalf("Expected uniqueAttr 'us-central1|my-keyring|my-key-1', got: %v", uniqueAttr1) + } + }) + t.Run("List_Unsupported", func(t *testing.T) { cache := sdpcache.NewCache(ctx) defer cache.Clear() diff --git a/sources/gcp/manual/cloud-kms-key-ring.go 
b/sources/gcp/manual/cloud-kms-key-ring.go index a5cd48e2..fcba983a 100644 --- a/sources/gcp/manual/cloud-kms-key-ring.go +++ b/sources/gcp/manual/cloud-kms-key-ring.go @@ -57,8 +57,12 @@ func (c cloudKMSKeyRingWrapper) PotentialLinks() map[shared.ItemType]bool { func (c cloudKMSKeyRingWrapper) TerraformMappings() []*sdp.TerraformMapping { return []*sdp.TerraformMapping{ { - TerraformMethod: sdp.QueryMethod_GET, - TerraformQueryMap: "google_kms_key_ring.name", + TerraformMethod: sdp.QueryMethod_SEARCH, + // https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/kms_key_ring + // ID format: projects/{{project}}/locations/{{location}}/keyRings/{{name}} + // The framework automatically intercepts queries starting with "projects/" and converts + // them to GET operations by extracting the last N path parameters (based on GetLookups count). + TerraformQueryMap: "google_kms_key_ring.id", }, } } diff --git a/sources/gcp/manual/cloud-kms-key-ring_test.go b/sources/gcp/manual/cloud-kms-key-ring_test.go index f4f79a97..db4b413c 100644 --- a/sources/gcp/manual/cloud-kms-key-ring_test.go +++ b/sources/gcp/manual/cloud-kms-key-ring_test.go @@ -19,7 +19,7 @@ func TestCloudKMSKeyRing(t *testing.T) { projectID := "test-project-id" t.Run("Get_CacheHit", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Pre-populate cache with a KeyRing item (simulating what the loader would do) @@ -64,7 +64,7 @@ func TestCloudKMSKeyRing(t *testing.T) { }) t.Run("Get_CacheMiss_NotFound", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Pre-populate cache with a NOTFOUND error to simulate item not existing @@ -95,7 +95,7 @@ func TestCloudKMSKeyRing(t *testing.T) { }) t.Run("List_CacheHit", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Pre-populate cache with KeyRing items under LIST cache key @@ -149,7 +149,7 @@ func TestCloudKMSKeyRing(t *testing.T) { }) t.Run("List_CacheHit_Empty", func(t *testing.T) { - cache := sdpcache.NewCache(ctx) + cache := sdpcache.NewMemoryCache() defer cache.Clear() // Store NOTFOUND error in cache to simulate empty result @@ -181,7 +181,7 @@ func TestCloudKMSKeyRing(t *testing.T) { } }) - t.Run("Search_CacheHit", func(t *testing.T) { + t.Run("Search_CacheHit_ByLocation", func(t *testing.T) { cache := sdpcache.NewCache(ctx) defer cache.Clear() @@ -222,10 +222,117 @@ func TestCloudKMSKeyRing(t *testing.T) { } }) - t.Run("StaticTests", func(t *testing.T) { + t.Run("Search_TerraformFormat", func(t *testing.T) { cache := sdpcache.NewCache(ctx) defer cache.Clear() + // Pre-populate cache with KeyRing item + attrs, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/us-central1/keyRings/my-keyring", + "uniqueAttr": "us-central1|my-keyring", + }) + _ = attrs.Set("uniqueAttr", "us-central1|my-keyring") + + item := &sdp.Item{ + Type: gcpshared.CloudKMSKeyRing.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs, + Scope: projectID, + } + + // Store with location-based search key (terraform format is converted to location) + searchCacheKey := sdpcache.CacheKeyFromParts("gcp-source", sdp.QueryMethod_SEARCH, projectID, gcpshared.CloudKMSKeyRing.String(), "us-central1") + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, searchCacheKey) + + loader := gcpshared.NewCloudKMSAssetLoader(nil, projectID, cache, "gcp-source", 
[]gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + + wrapper := manual.NewCloudKMSKeyRing(loader, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, cache) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + // Search using terraform-style path format + // The SearchStream will extract the location and search by that + terraformID := "projects/test-project-id/locations/us-central1/keyRings/my-keyring" + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], terraformID, false) + if qErr != nil { + t.Fatalf("Expected no error with terraform format, got: %v", qErr) + } + + if len(items) != 1 { + t.Fatalf("Expected 1 item with terraform format, got: %d", len(items)) + } + + // Verify the returned item has the correct unique attribute + uniqueAttr, err := items[0].GetAttributes().Get("uniqueAttr") + if err != nil { + t.Fatalf("Failed to get uniqueAttr: %v", err) + } + if uniqueAttr != "us-central1|my-keyring" { + t.Fatalf("Expected uniqueAttr 'us-central1|my-keyring', got: %v", uniqueAttr) + } + }) + + t.Run("Search_LegacyLocationFormat", func(t *testing.T) { + cache := sdpcache.NewCache(ctx) + defer cache.Clear() + + // Pre-populate cache with KeyRing item + attrs, _ := sdp.ToAttributesViaJson(map[string]interface{}{ + "name": "projects/test-project-id/locations/us-central1/keyRings/my-keyring", + "uniqueAttr": "us-central1|my-keyring", + }) + _ = attrs.Set("uniqueAttr", "us-central1|my-keyring") + + item := &sdp.Item{ + Type: gcpshared.CloudKMSKeyRing.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attrs, + Scope: projectID, + } + + // Store with location-based search key + searchCacheKey := sdpcache.CacheKeyFromParts("gcp-source", sdp.QueryMethod_SEARCH, projectID, gcpshared.CloudKMSKeyRing.String(), "us-central1") + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, searchCacheKey) + + loader := gcpshared.NewCloudKMSAssetLoader(nil, projectID, cache, "gcp-source", []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + + wrapper := manual.NewCloudKMSKeyRing(loader, []gcpshared.LocationInfo{gcpshared.NewProjectLocation(projectID)}) + adapter := sources.WrapperToAdapter(wrapper, cache) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + // Search using legacy location format + legacyQuery := "us-central1" + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], legacyQuery, false) + if qErr != nil { + t.Fatalf("Expected no error with legacy format, got: %v", qErr) + } + + if len(items) != 1 { + t.Fatalf("Expected 1 item with legacy format, got: %d", len(items)) + } + + // Verify the returned item has the correct unique attribute + uniqueAttr, err := items[0].GetAttributes().Get("uniqueAttr") + if err != nil { + t.Fatalf("Failed to get uniqueAttr: %v", err) + } + if uniqueAttr != "us-central1|my-keyring" { + t.Fatalf("Expected uniqueAttr 'us-central1|my-keyring', got: %v", uniqueAttr) + } + }) + + t.Run("StaticTests", func(t *testing.T) { + cache := sdpcache.NewMemoryCache() + defer cache.Clear() + // Pre-populate cache with a KeyRing item attrs, _ := sdp.ToAttributesViaJson(map[string]interface{}{ "name": "projects/test-project-id/locations/us/keyRings/test-keyring", diff --git a/sources/gcp/manual/iam-service-account-key.go b/sources/gcp/manual/iam-service-account-key.go index bdf2c6d4..fd18c5be 100644 --- 
a/sources/gcp/manual/iam-service-account-key.go +++ b/sources/gcp/manual/iam-service-account-key.go @@ -55,9 +55,11 @@ func (c iamServiceAccountKeyWrapper) PotentialLinks() map[shared.ItemType]bool { func (c iamServiceAccountKeyWrapper) TerraformMappings() []*sdp.TerraformMapping { return []*sdp.TerraformMapping{ { - TerraformMethod: sdp.QueryMethod_GET, - // https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/compute_snapshot#argument-reference - // projects/{{project}}/serviceAccounts/{{account}}/keys/{{key}} + TerraformMethod: sdp.QueryMethod_SEARCH, + // https://registry.terraform.io/providers/hashicorp/google/latest/docs/resources/service_account_key + // ID format: projects/{{project}}/serviceAccounts/{{account}}/keys/{{key}} + // The framework automatically intercepts queries starting with "projects/" and converts + // them to GET operations by extracting the last N path parameters (based on GetLookups count). TerraformQueryMap: "google_service_account_key.id", }, } diff --git a/sources/gcp/proc/proc.go b/sources/gcp/proc/proc.go index 40e3dbed..49f8c3b0 100644 --- a/sources/gcp/proc/proc.go +++ b/sources/gcp/proc/proc.go @@ -320,6 +320,7 @@ func init() { sdpcache.NewNoOpCache(), // no-op cache for metadata registration ) if err != nil { + // docs generation should fail if there are errors creating adapters panic(fmt.Errorf("error creating adapters: %w", err)) } @@ -330,14 +331,11 @@ func init() { log.Debug("Registered GCP source metadata", " with ", len(Metadata.AllAdapterMetadata()), " adapters") } -func Initialize(ctx context.Context, ec *discovery.EngineConfig, cfg *GCPConfig) (*discovery.Engine, error) { - engine, err := discovery.NewEngine(ec) - if err != nil { - return nil, fmt.Errorf("error initializing Engine: %w", err) - } - - var healthChecker *ProjectHealthChecker - +// InitializeAdapters adds GCP adapters to an existing engine. This is a single-attempt +// function; retry logic is handled by the caller via Engine.InitialiseAdapters. +// +// cfg must not be nil — call ConfigFromViper() first for config validation. 
+func InitializeAdapters(ctx context.Context, engine *discovery.Engine, cfg *GCPConfig) error { // ReadinessCheck verifies adapters are healthy by using a CloudResourceManagerProject adapter // Timeout is handled by SendHeartbeat, HTTP handlers rely on request context engine.SetReadinessCheck(func(ctx context.Context) error { @@ -364,215 +362,186 @@ func Initialize(ctx context.Context, ec *discovery.EngineConfig, cfg *GCPConfig) // Create a shared cache for all adapters in this source sharedCache := sdpcache.NewCache(ctx) - err = func() error { - var logmsg string - // Use provided config, otherwise fall back to viper - if cfg != nil { - logmsg = "Using directly provided config" - } else { - var err error - cfg, err = readConfig() - if err != nil { - return fmt.Errorf("error creating config from command line: %w", err) - } - logmsg = "Using config from viper" + // Determine which projects to use based on the parent configuration + var projectIDs []string + if cfg.Parent == "" { + // No parent specified - discover all accessible projects + log.WithFields(log.Fields{ + "ovm.source.type": "gcp", + }).Info("No parent specified, discovering all accessible projects") + discoveredProjects, err := discoverProjects(ctx, cfg.ImpersonationServiceAccountEmail) + if err != nil { + return fmt.Errorf("error discovering projects: %w", err) } - // Determine which projects to use based on the parent configuration - var projectIDs []string - if cfg.Parent == "" { - // No parent specified - discover all accessible projects - log.WithFields(log.Fields{ - "ovm.source.type": "gcp", - }).Info("No parent specified, discovering all accessible projects") + projectIDs = discoveredProjects + } else { + // Parent is specified - determine its type and discover accordingly + parentType, err := detectParentType(cfg.Parent) + if err != nil { + return fmt.Errorf("error detecting parent type: %w", err) + } - discoveredProjects, err := discoverProjects(ctx, cfg.ImpersonationServiceAccountEmail) - if err != nil { - return fmt.Errorf("error discovering projects: %w", err) - } + normalizedParent, err := normalizeParent(cfg.Parent, parentType) + if err != nil { + return fmt.Errorf("error normalizing parent: %w", err) + } - projectIDs = discoveredProjects - } else { - // Parent is specified - determine its type and discover accordingly - parentType, err := detectParentType(cfg.Parent) - if err != nil { - return fmt.Errorf("error detecting parent type: %w", err) - } + switch parentType { + case ParentTypeProject: + // Single project - no discovery needed + log.WithFields(log.Fields{ + "ovm.source.type": "gcp", + "ovm.source.parent": cfg.Parent, + "ovm.source.project_id": normalizedParent, + }).Info("Using specified project") + projectIDs = []string{normalizedParent} + + case ParentTypeOrganization, ParentTypeFolder: + // Organization or folder - discover all projects within it + log.WithFields(log.Fields{ + "ovm.source.type": "gcp", + "ovm.source.parent": cfg.Parent, + "parent_type": parentType, + }).Info("Discovering projects under parent") - normalizedParent, err := normalizeParent(cfg.Parent, parentType) + discoveredProjects, err := discoverProjectsUnderSpecificParent(ctx, cfg.Parent, cfg.ImpersonationServiceAccountEmail) if err != nil { - return fmt.Errorf("error normalizing parent: %w", err) + return fmt.Errorf("error discovering projects under parent %s: %w", cfg.Parent, err) } - switch parentType { - case ParentTypeProject: - // Single project - no discovery needed - log.WithFields(log.Fields{ - "ovm.source.type": "gcp", - 
"ovm.source.parent": cfg.Parent, - "ovm.source.project_id": normalizedParent, - }).Info("Using specified project") - projectIDs = []string{normalizedParent} - - case ParentTypeOrganization, ParentTypeFolder: - // Organization or folder - discover all projects within it - log.WithFields(log.Fields{ - "ovm.source.type": "gcp", - "ovm.source.parent": cfg.Parent, - "parent_type": parentType, - }).Info("Discovering projects under parent") - - discoveredProjects, err := discoverProjectsUnderSpecificParent(ctx, cfg.Parent, cfg.ImpersonationServiceAccountEmail) - if err != nil { - return fmt.Errorf("error discovering projects under parent %s: %w", cfg.Parent, err) - } - - if len(discoveredProjects) == 0 { - return fmt.Errorf("no accessible projects found under parent %s. Please ensure the service account has the 'resourcemanager.projects.list' permission via the 'roles/browser' predefined GCP role", cfg.Parent) - } - - projectIDs = discoveredProjects - - case ParentTypeUnknown: - return fmt.Errorf("unknown parent type for parent: %s", cfg.Parent) - - default: - return fmt.Errorf("unknown parent type for parent: %s", cfg.Parent) + if len(discoveredProjects) == 0 { + return fmt.Errorf("no accessible projects found under parent %s. Please ensure the service account has the 'resourcemanager.projects.list' permission via the 'roles/browser' predefined GCP role", cfg.Parent) } - } - logFields := log.Fields{ - "ovm.source.type": "gcp", - "ovm.source.project_count": len(projectIDs), - "ovm.source.regions": cfg.Regions, - "ovm.source.zones": cfg.Zones, - "ovm.source.impersonation-service-account-email": cfg.ImpersonationServiceAccountEmail, - } - if cfg.Parent == "" { - logFields["ovm.source.parent"] = "" - } else { - logFields["ovm.source.parent"] = cfg.Parent - } - if cfg.ProjectID != "" { - logFields["ovm.source.project_id"] = cfg.ProjectID - } - log.WithFields(logFields).Info(logmsg) + projectIDs = discoveredProjects - // If still no regions/zones this is no valid config. - if len(cfg.Regions) == 0 && len(cfg.Zones) == 0 { - return fmt.Errorf("GCP source must specify at least one region or zone") + case ParentTypeUnknown: + return fmt.Errorf("unknown parent type for parent: %s", cfg.Parent) + + default: + return fmt.Errorf("unknown parent type for parent: %s", cfg.Parent) } + } - linker := gcpshared.NewLinker() + logFields := log.Fields{ + "ovm.source.type": "gcp", + "ovm.source.project_count": len(projectIDs), + "ovm.source.regions": cfg.Regions, + "ovm.source.zones": cfg.Zones, + "ovm.source.impersonation-service-account-email": cfg.ImpersonationServiceAccountEmail, + } + if cfg.Parent == "" { + logFields["ovm.source.parent"] = "" + } else { + logFields["ovm.source.parent"] = cfg.Parent + } + if cfg.ProjectID != "" { + logFields["ovm.source.project_id"] = cfg.ProjectID + } + log.WithFields(logFields).Info("Got config") - // Build LocationInfo slices for all projects, regions, and zones - projectLocations := make([]gcpshared.LocationInfo, 0, len(projectIDs)) - for _, projectID := range projectIDs { - projectLocations = append(projectLocations, gcpshared.NewProjectLocation(projectID)) - } + // If still no regions/zones this is no valid config. 
+ if len(cfg.Regions) == 0 && len(cfg.Zones) == 0 { + return fmt.Errorf("GCP source must specify at least one region or zone") + } - regionLocations := make([]gcpshared.LocationInfo, 0, len(projectIDs)*len(cfg.Regions)) - for _, projectID := range projectIDs { - for _, region := range cfg.Regions { - regionLocations = append(regionLocations, gcpshared.NewRegionalLocation(projectID, region)) - } - } + linker := gcpshared.NewLinker() - zoneLocations := make([]gcpshared.LocationInfo, 0, len(projectIDs)*len(cfg.Zones)) - for _, projectID := range projectIDs { - for _, zone := range cfg.Zones { - zoneLocations = append(zoneLocations, gcpshared.NewZonalLocation(projectID, zone)) - } - } + // Build LocationInfo slices for all projects, regions, and zones + projectLocations := make([]gcpshared.LocationInfo, 0, len(projectIDs)) + for _, projectID := range projectIDs { + projectLocations = append(projectLocations, gcpshared.NewProjectLocation(projectID)) + } - // Create adapters once for all projects using pre-built LocationInfo - log.WithFields(log.Fields{ - "ovm.source.type": "gcp", - "ovm.source.project_count": len(projectIDs), - }).Debug("Creating multi-project adapters") - - allAdapters, err := adapters( - ctx, - projectLocations, - regionLocations, - zoneLocations, - cfg.ImpersonationServiceAccountEmail, - linker, - true, - sharedCache, - ) - if err != nil { - return fmt.Errorf("error creating discovery adapters: %w", err) + regionLocations := make([]gcpshared.LocationInfo, 0, len(projectIDs)*len(cfg.Regions)) + for _, projectID := range projectIDs { + for _, region := range cfg.Regions { + regionLocations = append(regionLocations, gcpshared.NewRegionalLocation(projectID, region)) } + } - // Find the single multi-project CloudResourceManagerProject adapter - var cloudResourceManagerProjectAdapter discovery.Adapter - for _, adapter := range allAdapters { - if adapter.Type() == gcpshared.CloudResourceManagerProject.String() { - cloudResourceManagerProjectAdapter = adapter - break - } + zoneLocations := make([]gcpshared.LocationInfo, 0, len(projectIDs)*len(cfg.Zones)) + for _, projectID := range projectIDs { + for _, zone := range cfg.Zones { + zoneLocations = append(zoneLocations, gcpshared.NewZonalLocation(projectID, zone)) } + } - if cloudResourceManagerProjectAdapter == nil { - return fmt.Errorf("cloud resource manager project adapter not found") - } + // Create adapters once for all projects using pre-built LocationInfo + log.WithFields(log.Fields{ + "ovm.source.type": "gcp", + "ovm.source.project_count": len(projectIDs), + }).Debug("Creating multi-project adapters") - // Create health checker with single multi-project adapter and 5 minute cache duration - healthChecker = NewProjectHealthChecker( - projectIDs, - cloudResourceManagerProjectAdapter, - 5*time.Minute, - ) + allAdapters, err := adapters( + ctx, + projectLocations, + regionLocations, + zoneLocations, + cfg.ImpersonationServiceAccountEmail, + linker, + true, + sharedCache, + ) + if err != nil { + return fmt.Errorf("error creating discovery adapters: %w", err) + } - // Run initial permission check before starting the source to fail fast if - // we don't have the required permissions. This validates that we can access - // the Cloud Resource Manager API for all configured projects. 
- result, err := healthChecker.Check(ctx) - if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "ovm.source.type": "gcp", - "ovm.source.success_count": result.SuccessCount, - "ovm.source.failure_count": result.FailureCount, - "ovm.source.project_count": len(projectIDs), - }).Error("Permission check failed for some projects") - } else { - log.WithFields(log.Fields{ - "ovm.source.type": "gcp", - "ovm.source.success_count": result.SuccessCount, - "ovm.source.project_count": len(projectIDs), - }).Info("All projects passed permission checks") + // Find the single multi-project CloudResourceManagerProject adapter + var cloudResourceManagerProjectAdapter discovery.Adapter + for _, adapter := range allAdapters { + if adapter.Type() == gcpshared.CloudResourceManagerProject.String() { + cloudResourceManagerProjectAdapter = adapter + break } + } - // Add the adapters to the engine - err = engine.AddAdapters(allAdapters...) - if err != nil { - return fmt.Errorf("error adding adapters to engine: %w", err) - } + if cloudResourceManagerProjectAdapter == nil { + return fmt.Errorf("cloud resource manager project adapter not found") + } - return nil - }() + // Create health checker with single multi-project adapter and 5 minute cache duration + healthChecker := NewProjectHealthChecker( + projectIDs, + cloudResourceManagerProjectAdapter, + 5*time.Minute, + ) + // Run initial permission check before starting the source to fail fast if + // we don't have the required permissions. This validates that we can access + // the Cloud Resource Manager API for all configured projects. + result, err := healthChecker.Check(ctx) if err != nil { - log.WithError(err).Debug("Error initializing GCP source") - return nil, fmt.Errorf("error initializing GCP source: %w", err) + log.WithContext(ctx).WithError(err).WithFields(log.Fields{ + "ovm.source.type": "gcp", + "ovm.source.success_count": result.SuccessCount, + "ovm.source.failure_count": result.FailureCount, + "ovm.source.project_count": len(projectIDs), + }).Error("Permission check failed for some projects") + } else { + log.WithFields(log.Fields{ + "ovm.source.type": "gcp", + "ovm.source.success_count": result.SuccessCount, + "ovm.source.project_count": len(projectIDs), + }).Info("All projects passed permission checks") } - // Start sending heartbeats after adapters are successfully added - // This ensures the first heartbeat has adapters available for readiness checks - engine.StartSendingHeartbeats(ctx) - brokenHeart := engine.SendHeartbeat(ctx, nil) // Send the error immediately through the custom health check func - if brokenHeart != nil { - log.WithError(brokenHeart).Error("Error sending heartbeat") + // Add the adapters to the engine + err = engine.AddAdapters(allAdapters...) + if err != nil { + return fmt.Errorf("error adding adapters to engine: %w", err) } log.Debug("Sources initialized") - // If there is no error then return the engine - return engine, nil + return nil } -func readConfig() (*GCPConfig, error) { +// ConfigFromViper reads and validates the GCP configuration from viper flags. +// This performs local validation only (no API calls) and should be called +// before InitializeAdapters to catch permanent config errors early. 
+func ConfigFromViper() (*GCPConfig, error) { parent := viper.GetString("gcp-parent") projectID := viper.GetString("gcp-project-id") diff --git a/sources/gcp/shared/utils.go b/sources/gcp/shared/utils.go index 641bddd3..cdfbd5db 100644 --- a/sources/gcp/shared/utils.go +++ b/sources/gcp/shared/utils.go @@ -130,6 +130,27 @@ func ZoneToRegion(zone string) string { return strings.Join(parts[:len(parts)-1], "-") } +// isProjectNumber returns true if the project identifier appears to be a +// GCP project number (all digits) rather than a project ID. Project IDs +// must start with a letter per GCP rules. +// +// We use a simple loop instead of a regex (e.g., `^\d+$`) because: +// - It's more idiomatic Go for simple character validation +// - Avoids regex compilation/matching overhead (even pre-compiled) +// - More readable for maintainers unfamiliar with regex +// - Sufficient for the straightforward "all digits" check +func isProjectNumber(projectID string) bool { + if projectID == "" { + return false + } + for _, r := range projectID { + if r < '0' || r > '9' { + return false + } + } + return true +} + // ExtractScopeFromURI extracts the scope from a GCP resource URI. // It supports various URL formats including full HTTPS URLs, full resource names, // service destination formats, and bare paths. @@ -183,6 +204,13 @@ func ExtractScopeFromURI(ctx context.Context, uri string) (string, error) { return "", err } + // When URI uses project number instead of project ID, we cannot map to + // adapter scopes (which use project IDs). Return wildcard so the query + // is broadcast to all adapters. + if isProjectNumber(projectID) { + return "*", nil + } + // Check for conflicting location specifiers if zone != "" && region != "" { err := fmt.Errorf("cannot determine scope: both zones and regions found in URI: %s", uri) diff --git a/sources/gcp/shared/utils_test.go b/sources/gcp/shared/utils_test.go index 70d726a1..0b232771 100644 --- a/sources/gcp/shared/utils_test.go +++ b/sources/gcp/shared/utils_test.go @@ -590,6 +590,27 @@ func TestExtractScopeFromURI(t *testing.T) { uri: "https://pubsub.googleapis.com/v1/projects/my-project/topics/my-topic", expected: "my-project", }, + // Project number cases (wildcard scope) + { + name: "Project number - Global resource", + uri: "projects/96771641962/global/instanceTemplates/my-template", + expected: "*", + }, + { + name: "Project number - Regional resource", + uri: "projects/96771641962/regions/us-central1/subnetworks/my-subnet", + expected: "*", + }, + { + name: "Project number - Zonal resource", + uri: "projects/96771641962/zones/us-central1-a/disks/my-disk", + expected: "*", + }, + { + name: "Project number - Short numeric", + uri: "projects/123/global/networks/my-network", + expected: "*", + }, // Error cases { name: "Error - Empty URI", diff --git a/sources/transformer.go b/sources/transformer.go index cef07012..4d673366 100644 --- a/sources/transformer.go +++ b/sources/transformer.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "strings" - "sync" "buf.build/go/protovalidate" log "github.com/sirupsen/logrus" @@ -218,22 +217,6 @@ type standardSearchableListableAdapterImpl struct { // Standard Adapter Core methods // ***************************** -var ( - noOpCacheTransformerOnce sync.Once - noOpCacheTransformer sdpcache.Cache -) - -// Cache returns the cache of the adapter. 
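A minimal sketch of how the project-number handling in sources/gcp/shared/utils.go above is expected to behave, assuming the repository's module path for the shared package (illustrative only, not part of this change):

package main

import (
	"context"
	"fmt"

	gcpshared "github.com/overmindtech/cli/sources/gcp/shared"
)

func main() {
	// A URI that carries a numeric project number rather than a project ID.
	uri := "projects/96771641962/zones/us-central1-a/disks/my-disk"

	scope, err := gcpshared.ExtractScopeFromURI(context.Background(), uri)
	if err != nil {
		panic(err)
	}

	// Prints "*": adapter scopes are keyed by project ID, so a project number
	// cannot be resolved to a concrete scope and the query is broadcast to all
	// adapters instead.
	fmt.Println(scope)
}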
-func (s *standardAdapterCore) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheTransformerOnce.Do(func() { - noOpCacheTransformer = sdpcache.NewNoOpCache() - }) - return noOpCacheTransformer - } - return s.cache -} - // Type returns the type of the adapter. func (s *standardAdapterCore) Type() string { return s.wrapper.Type() @@ -275,7 +258,7 @@ func (s *standardAdapterCore) Get(ctx context.Context, scope string, query strin return nil, err } - cacheHit, ck, cachedItem, qErr, done := s.Cache().Lookup( + cacheHit, ck, cachedItem, qErr, done := s.cache.Lookup( ctx, s.Name(), sdp.QueryMethod_GET, @@ -311,12 +294,12 @@ func (s *standardAdapterCore) Get(ctx context.Context, scope string, query strin item, err := s.wrapper.Get(ctx, scope, queryParts...) if err != nil { - s.Cache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + s.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } // Store in cache after successful get - s.Cache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + s.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) return item, nil } @@ -386,7 +369,7 @@ func (s *standardListableAdapterImpl) List(ctx context.Context, scope string, ig return nil, nil } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup( + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup( ctx, s.Name(), sdp.QueryMethod_LIST, @@ -413,12 +396,12 @@ func (s *standardListableAdapterImpl) List(ctx context.Context, scope string, ig items, err := s.listable.List(ctx, scope) if err != nil { - s.Cache().StoreError(ctx, err, shared.DefaultCacheDuration, ck) + s.cache.StoreError(ctx, err, shared.DefaultCacheDuration, ck) return nil, err } for _, item := range items { - s.Cache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + s.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) } return items, nil @@ -435,7 +418,7 @@ func (s *standardListableAdapterImpl) ListStream(ctx context.Context, scope stri return } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup( + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup( ctx, s.Name(), sdp.QueryMethod_LIST, @@ -644,7 +627,7 @@ func (s *standardSearchableAdapterImpl) SearchStream(ctx context.Context, scope return } - cacheHit, ck, cachedItems, qErr, done := s.Cache().Lookup( + cacheHit, ck, cachedItems, qErr, done := s.cache.Lookup( ctx, s.Name(), sdp.QueryMethod_SEARCH, @@ -700,7 +683,7 @@ func (s *standardSearchableAdapterImpl) SearchStream(ctx context.Context, scope return } - s.Cache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + s.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) stream.SendItem(item) return @@ -745,7 +728,7 @@ func (s *standardSearchableAdapterImpl) SearchStream(ctx context.Context, scope return } - s.Cache().StoreItem(ctx, item, shared.DefaultCacheDuration, ck) + s.cache.StoreItem(ctx, item, shared.DefaultCacheDuration, ck) stream.SendItem(item) return diff --git a/stdlib-source/adapters/dns.go b/stdlib-source/adapters/dns.go index eddab2b5..c5811aa6 100644 --- a/stdlib-source/adapters/dns.go +++ b/stdlib-source/adapters/dns.go @@ -7,7 +7,6 @@ import ( "net" "sort" "strings" - "sync" "time" "github.com/cenkalti/backoff/v5" @@ -32,23 +31,16 @@ type DNSAdapter struct { cache sdpcache.Cache // This is mandatory } -const dnsCacheDuration = 5 * time.Minute - -var ( - noOpCacheDNSOnce sync.Once - noOpCacheDNS sdpcache.Cache -) - -func (d *DNSAdapter) Cache() sdpcache.Cache { - if d.cache == nil { - noOpCacheDNSOnce.Do(func() { - noOpCacheDNS = 
sdpcache.NewNoOpCache() - }) - return noOpCacheDNS +// NewDNSAdapterForHealthCheck creates a DNSAdapter with a NoOpCache for use in health checks. +// This is useful when you need a DNSAdapter but don't need caching functionality. +func NewDNSAdapterForHealthCheck() *DNSAdapter { + return &DNSAdapter{ + cache: sdpcache.NewNoOpCache(), } - return d.cache } +const dnsCacheDuration = 5 * time.Minute + var DefaultServers = []string{ "169.254.169.253:53", // Route 53 default resolver. See https://docs.aws.amazon.com/vpc/latest/userguide/AmazonDNS-concepts.html#AmazonDNS "1.1.1.1:53", @@ -139,7 +131,7 @@ func (d *DNSAdapter) Get(ctx context.Context, scope string, query string, ignore var qErr *sdp.QueryError var done func() - cacheHit, ck, cachedItems, qErr, done = d.Cache().Lookup(ctx, d.Name(), sdp.QueryMethod_GET, scope, d.Type(), query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done = d.cache.Lookup(ctx, d.Name(), sdp.QueryMethod_GET, scope, d.Type(), query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -169,7 +161,7 @@ func (d *DNSAdapter) Get(ctx context.Context, scope string, query string, ignore ItemType: d.Type(), } } - d.Cache().StoreItem(ctx, items[0], dnsCacheDuration, ck) + d.cache.StoreItem(ctx, items[0], dnsCacheDuration, ck) return items[0], nil } @@ -210,7 +202,7 @@ func (d *DNSAdapter) Search(ctx context.Context, scope string, query string, ign var qErr *sdp.QueryError var done func() if !ignoreCache { - cacheHit, _, cachedItems, qErr, done = d.Cache().Lookup(ctx, d.Name(), sdp.QueryMethod_SEARCH, scope, d.Type(), query, ignoreCache) + cacheHit, _, cachedItems, qErr, done = d.cache.Lookup(ctx, d.Name(), sdp.QueryMethod_SEARCH, scope, d.Type(), query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -228,7 +220,7 @@ func (d *DNSAdapter) Search(ctx context.Context, scope string, query string, ign // If it's an IP then we want to run a reverse lookup items, err := d.MakeReverseQuery(ctx, query) if err != nil { - d.Cache().StoreError(ctx, err, dnsCacheDuration, ck) + d.cache.StoreError(ctx, err, dnsCacheDuration, ck) return nil, err } @@ -241,12 +233,12 @@ func (d *DNSAdapter) Search(ctx context.Context, scope string, query string, ign SourceName: d.Name(), ItemType: d.Type(), } - d.Cache().StoreError(ctx, notFoundErr, dnsCacheDuration, ck) + d.cache.StoreError(ctx, notFoundErr, dnsCacheDuration, ck) return nil, notFoundErr } for _, item := range items { - d.Cache().StoreItem(ctx, item, dnsCacheDuration, ck) + d.cache.StoreItem(ctx, item, dnsCacheDuration, ck) } return items, nil @@ -259,12 +251,12 @@ func (d *DNSAdapter) Search(ctx context.Context, scope string, query string, ign items, err := d.MakeQuery(ctx, query) if err != nil { - d.Cache().StoreError(ctx, err, dnsCacheDuration, ck) + d.cache.StoreError(ctx, err, dnsCacheDuration, ck) return nil, err } for _, item := range items { - d.Cache().StoreItem(ctx, item, dnsCacheDuration, ck) + d.cache.StoreItem(ctx, item, dnsCacheDuration, ck) } return items, nil diff --git a/stdlib-source/adapters/dns_test.go b/stdlib-source/adapters/dns_test.go index 1a0c15bf..ca72183d 100644 --- a/stdlib-source/adapters/dns_test.go +++ b/stdlib-source/adapters/dns_test.go @@ -8,12 +8,14 @@ import ( "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) func TestSearch(t *testing.T) { t.Parallel() s := DNSAdapter{ + cache: sdpcache.NewNoOpCache(), Servers: []string{ "1.1.1.1:53", "8.8.8.8:53", @@ -112,7 +114,9 @@ func TestDnsGet(t *testing.T) { 
t.Skip("No internet connection detected") } - src := DNSAdapter{} + src := DNSAdapter{ + cache: sdpcache.NewNoOpCache(), + } t.Run("working request", func(t *testing.T) { item, err := src.Get(context.Background(), "global", "one.one.one.one", false) diff --git a/stdlib-source/adapters/http.go b/stdlib-source/adapters/http.go index 3ad07bdc..90ec4aef 100644 --- a/stdlib-source/adapters/http.go +++ b/stdlib-source/adapters/http.go @@ -10,7 +10,6 @@ import ( "net/url" "runtime" "strings" - "sync" "time" "github.com/overmindtech/cli/sdp-go" @@ -80,21 +79,6 @@ type HTTPAdapter struct { const httpCacheDuration = 5 * time.Minute -var ( - noOpCacheHTTPOnce sync.Once - noOpCacheHTTP sdpcache.Cache -) - -func (s *HTTPAdapter) Cache() sdpcache.Cache { - if s.cache == nil { - noOpCacheHTTPOnce.Do(func() { - noOpCacheHTTP = sdpcache.NewNoOpCache() - }) - return noOpCacheHTTP - } - return s.cache -} - // Type The type of items that this adapter is capable of finding func (s *HTTPAdapter) Type() string { return "http" @@ -174,7 +158,7 @@ func (s *HTTPAdapter) Get(ctx context.Context, scope string, query string, ignor ErrorString: err.Error(), Scope: scope, } - s.Cache().StoreError(ctx, err, httpCacheDuration, ck) + s.cache.StoreError(ctx, err, httpCacheDuration, ck) return nil, err } } @@ -185,7 +169,7 @@ func (s *HTTPAdapter) Get(ctx context.Context, scope string, query string, ignor var qErr *sdp.QueryError var done func() - cacheHit, ck, cachedItems, qErr, done = s.Cache().Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.Type(), query, ignoreCache) + cacheHit, ck, cachedItems, qErr, done = s.cache.Lookup(ctx, s.Name(), sdp.QueryMethod_GET, scope, s.Type(), query, ignoreCache) defer done() if qErr != nil { return nil, qErr @@ -221,7 +205,7 @@ func (s *HTTPAdapter) Get(ctx context.Context, scope string, query string, ignor ErrorString: err.Error(), Scope: scope, } - s.Cache().StoreError(ctx, err, httpCacheDuration, ck) + s.cache.StoreError(ctx, err, httpCacheDuration, ck) return nil, err } @@ -238,7 +222,7 @@ func (s *HTTPAdapter) Get(ctx context.Context, scope string, query string, ignor ErrorString: err.Error(), Scope: scope, } - s.Cache().StoreError(ctx, err, httpCacheDuration, ck) + s.cache.StoreError(ctx, err, httpCacheDuration, ck) return nil, err } @@ -270,7 +254,7 @@ func (s *HTTPAdapter) Get(ctx context.Context, scope string, query string, ignor ErrorString: err.Error(), Scope: scope, } - s.Cache().StoreError(ctx, err, httpCacheDuration, ck) + s.cache.StoreError(ctx, err, httpCacheDuration, ck) return nil, err } @@ -418,7 +402,7 @@ func (s *HTTPAdapter) Get(ctx context.Context, scope string, query string, ignor } } } - s.Cache().StoreItem(ctx, &item, httpCacheDuration, ck) + s.cache.StoreItem(ctx, &item, httpCacheDuration, ck) return &item, nil } diff --git a/stdlib-source/adapters/http_test.go b/stdlib-source/adapters/http_test.go index 88d90117..362b8ff0 100644 --- a/stdlib-source/adapters/http_test.go +++ b/stdlib-source/adapters/http_test.go @@ -13,6 +13,7 @@ import ( "github.com/overmindtech/cli/discovery" "github.com/overmindtech/cli/sdp-go" + "github.com/overmindtech/cli/sdpcache" ) const TestHTTPTimeout = 3 * time.Second @@ -115,7 +116,9 @@ func (t *TestHTTPServer) Close() { } func TestHTTPGet(t *testing.T) { - src := HTTPAdapter{} + src := HTTPAdapter{ + cache: sdpcache.NewNoOpCache(), + } server, err := NewTestServer() if err != nil { t.Fatal(err) @@ -433,7 +436,9 @@ func TestHTTPGet(t *testing.T) { } func TestHTTPSearch(t *testing.T) { - src := HTTPAdapter{} + src := HTTPAdapter{ 
+ cache: sdpcache.NewNoOpCache(), + } server, err := NewTestServer() if err != nil { t.Fatal(err) diff --git a/stdlib-source/adapters/main.go b/stdlib-source/adapters/main.go index eed2da43..a85612eb 100644 --- a/stdlib-source/adapters/main.go +++ b/stdlib-source/adapters/main.go @@ -13,7 +13,6 @@ import ( "github.com/overmindtech/cli/sdp-go" "github.com/overmindtech/cli/sdpcache" "github.com/overmindtech/cli/stdlib-source/adapters/test" - log "github.com/sirupsen/logrus" _ "embed" ) @@ -23,14 +22,12 @@ var Metadata = sdp.AdapterMetadataList{} // Cache duration for RDAP adapters, these things shouldn't change very often const RdapCacheDuration = 30 * time.Minute -func InitializeEngine(ctx context.Context, ec *discovery.EngineConfig, reverseDNS bool) (*discovery.Engine, error) { - e, err := discovery.NewEngine(ec) - if err != nil { - log.WithFields(log.Fields{ - "error": err.Error(), - }).Fatal("Error initializing Engine") - } - +// InitializeAdapters adds stdlib adapters to an existing engine. This allows the engine +// to be created and serve health probes even if adapter initialization fails. +// +// Stdlib adapters rarely fail during initialization, but this pattern maintains consistency +// with other sources and allows for future error handling improvements. +func InitializeAdapters(ctx context.Context, e *discovery.Engine, reverseDNS bool) error { // Create a shared cache for all adapters in this source sharedCache := sdpcache.NewCache(ctx) @@ -79,9 +76,7 @@ func InitializeEngine(ctx context.Context, ec *discovery.EngineConfig, reverseDN // }, } - err = e.AddAdapters(adapters...) - - return e, err + return e.AddAdapters(adapters...) } // newRdapClient Creates a new RDAP client using otelhttp.DefaultClient. rdap is suspected to not be thread safe, so we create a new client for each request diff --git a/stdlib-source/adapters/rdap-asn_test.go b/stdlib-source/adapters/rdap-asn_test.go index 900f3e6e..51065960 100644 --- a/stdlib-source/adapters/rdap-asn_test.go +++ b/stdlib-source/adapters/rdap-asn_test.go @@ -13,7 +13,7 @@ func TestASNAdapterGet(t *testing.T) { src := &RdapASNAdapter{ ClientFac: func() *rdap.Client { return testRdapClient(t) }, - Cache: sdpcache.NewCache(t.Context()), + Cache: sdpcache.NewNoOpCache(), } item, err := src.Get(context.Background(), "global", "AS15169", false) diff --git a/stdlib-source/adapters/rdap-domain_test.go b/stdlib-source/adapters/rdap-domain_test.go index 1740cbdd..27fe878e 100644 --- a/stdlib-source/adapters/rdap-domain_test.go +++ b/stdlib-source/adapters/rdap-domain_test.go @@ -13,7 +13,7 @@ func TestDomainAdapterGet(t *testing.T) { src := &RdapDomainAdapter{ ClientFac: func() *rdap.Client { return testRdapClient(t) }, - Cache: sdpcache.NewCache(t.Context()), + Cache: sdpcache.NewNoOpCache(), } t.Run("without a dot", func(t *testing.T) { diff --git a/stdlib-source/adapters/rdap-entity_test.go b/stdlib-source/adapters/rdap-entity_test.go index b68ef8e7..eb90b64b 100644 --- a/stdlib-source/adapters/rdap-entity_test.go +++ b/stdlib-source/adapters/rdap-entity_test.go @@ -21,7 +21,7 @@ func TestEntityAdapterSearch(t *testing.T) { src := &RdapEntityAdapter{ ClientFac: func() *rdap.Client { return testRdapClient(t) }, - Cache: sdpcache.NewCache(t.Context()), + Cache: sdpcache.NewNoOpCache(), } for _, realUrl := range realUrls { diff --git a/stdlib-source/adapters/rdap-ip-network_test.go b/stdlib-source/adapters/rdap-ip-network_test.go index cdf898f0..b45b91af 100644 --- a/stdlib-source/adapters/rdap-ip-network_test.go +++ 
b/stdlib-source/adapters/rdap-ip-network_test.go
@@ -13,7 +13,7 @@ func TestIpNetworkAdapterSearch(t *testing.T) {
 	src := &RdapIPNetworkAdapter{
 		ClientFac: func() *rdap.Client { return testRdapClient(t) },
-		Cache:     sdpcache.NewCache(t.Context()),
+		Cache:     sdpcache.NewNoOpCache(),
 		IPCache:   NewIPCache[*rdap.IPNetwork](),
 	}
diff --git a/stdlib-source/adapters/rdap-nameserver_test.go b/stdlib-source/adapters/rdap-nameserver_test.go
index 151578e4..46882a5d 100644
--- a/stdlib-source/adapters/rdap-nameserver_test.go
+++ b/stdlib-source/adapters/rdap-nameserver_test.go
@@ -13,7 +13,7 @@ func TestNameserverAdapterSearch(t *testing.T) {
 	src := &RdapNameserverAdapter{
 		ClientFac: func() *rdap.Client { return testRdapClient(t) },
-		Cache:     sdpcache.NewCache(t.Context()),
+		Cache:     sdpcache.NewNoOpCache(),
 	}
 	items, err := src.Search(context.Background(), "global", "https://rdap.verisign.com/com/v1/nameserver/NS4.GOOGLE.COM", false)
diff --git a/stdlib-source/cmd/root.go b/stdlib-source/cmd/root.go
index 8b92e503..b65ecfe1 100644
--- a/stdlib-source/cmd/root.go
+++ b/stdlib-source/cmd/root.go
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"os"
 	"os/signal"
-	"strconv"
 	"strings"
 	"syscall"
@@ -25,21 +24,22 @@ var cfgFile string
 // rootCmd represents the base command when called without any subcommands
 var rootCmd = &cobra.Command{
-	Use:   "stdlib-source",
-	Short: "Standard library of remotely accessible items",
+	Use:          "stdlib-source",
+	Short:        "Standard library of remotely accessible items",
+	SilenceUsage: true,
 	Long: `Gets details of items that are globally scoped (usually) and able to be queried without authentication. `,
-	Run: func(cmd *cobra.Command, args []string) {
-		ctx := context.Background()
-		ctx, cancel := context.WithCancel(ctx)
-		defer cancel()
+	RunE: func(cmd *cobra.Command, args []string) error {
+		ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+		defer stop()
 		defer tracing.LogRecoverToReturn(ctx, "stdlib-source.root")
 		// get engine config
 		engineConfig, err := discovery.EngineConfigFromViper("stdlib", tracing.Version())
 		if err != nil {
-			log.WithError(err).Fatal("Could not get engine config from viper")
+			log.WithError(err).Error("Could not get engine config from viper")
+			return fmt.Errorf("could not get engine config from viper: %w", err)
 		}
 		reverseDNS := viper.GetBool("reverse-dns")
@@ -47,32 +47,18 @@ var rootCmd = &cobra.Command{
 			"reverse-dns": reverseDNS,
 		}).Info("Got config")
-		// Validate the auth params and create a token client if we are using
-		// auth
-		err = engineConfig.CreateClients()
+		// Create a basic engine first
+		e, err := discovery.NewEngine(engineConfig)
 		if err != nil {
 			sentry.CaptureException(err)
-			log.WithError(err).Fatal("could not create auth clients")
+			log.WithError(err).Error("Could not create engine")
+			return fmt.Errorf("could not create engine: %w", err)
 		}
-		e, err := adapters.InitializeEngine(
-			ctx,
-			engineConfig,
-			reverseDNS,
-		)
-		if err != nil {
-			log.WithError(err).Error("Could not initialize aws source")
-			return
-		}
-
-		// Start HTTP server for health checks
-		healthCheckPort := viper.GetString("service-port")
-		healthCheckPortInt, err := strconv.Atoi(healthCheckPort)
-		if err != nil {
-			log.WithError(err).WithFields(log.Fields{"service-port": healthCheckPort}).Fatal("Invalid service-port")
-		}
+		// Start HTTP server for health checks before initialization
+		healthCheckPort := viper.GetInt("health-check-port")
-		healthCheckDNSAdapter := adapters.DNSAdapter{}
+		healthCheckDNSAdapter := adapters.NewDNSAdapterForHealthCheck()
 		// Set up health checks
 		if e.EngineConfig.HeartbeatOptions == nil {
@@ -89,43 +75,48 @@ var rootCmd = &cobra.Command{
 			return nil
 		})
-		e.ServeHealthProbes(healthCheckPortInt)
+		e.ServeHealthProbes(healthCheckPort)
+		// Start the engine (NATS connection) before adapter init so heartbeats work
 		err = e.Start(ctx)
 		if err != nil {
-			log.WithFields(log.Fields{
-				"error": err,
-			}).Error("Could not start engine")
-
-			os.Exit(1)
+			sentry.CaptureException(err)
+			log.WithError(err).Error("Could not start engine")
+			return fmt.Errorf("could not start engine: %w", err)
 		}
-		sigs := make(chan os.Signal, 1)
-
-		signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
+		// Stdlib adapter initialization makes no external API calls, so there is no
+		// InitialiseAdapters retry wrapper; just use SetInitError on failure.
+		err = adapters.InitializeAdapters(ctx, e, reverseDNS)
+		if err != nil {
+			initErr := fmt.Errorf("could not initialize stdlib adapters: %w", err)
+			log.WithError(initErr).Error("Stdlib source initialization failed - pod will stay running with error status")
+			e.SetInitError(initErr)
+			sentry.CaptureException(initErr)
+		} else {
+			e.StartSendingHeartbeats(ctx)
+		}
-		<-sigs
+		<-ctx.Done()
 		log.Info("Stopping engine")
 		err = e.Stop()
 		if err != nil {
-			log.WithFields(log.Fields{
-				"error": err,
-			}).Error("Could not stop engine")
-
-			os.Exit(1)
+			log.WithError(err).Error("Could not stop engine")
+			return fmt.Errorf("could not stop engine: %w", err)
 		}
 		log.Info("Stopped")
-		os.Exit(0)
+		return nil
 	},
}
-// Execute adds all child commands to the root command and sets flags appropriately.add
-// This is called by main.main(). It only needs to happen once to the rootCmd.
+// Execute adds all child commands to the root command and sets flags
+// appropriately. This is called by main.main(). It only needs to happen once to
+// the rootCmd.
func Execute() { if err := rootCmd.Execute(); err != nil { fmt.Println(err) @@ -150,8 +140,8 @@ func init() { // engine config options discovery.AddEngineFlags(rootCmd) - rootCmd.PersistentFlags().String("service-port", "8089", "the port to listen on") - cobra.CheckErr(viper.BindEnv("service-port", "STDLIB_SERVICE_PORT", "SERVICE_PORT")) // fallback to srcman config + rootCmd.PersistentFlags().IntP("health-check-port", "", 8089, "The port that the health check should run on") + cobra.CheckErr(viper.BindEnv("health-check-port", "STDLIB_HEALTH_CHECK_PORT", "HEALTH_CHECK_PORT", "STDLIB_SERVICE_PORT", "SERVICE_PORT")) // new names + backwards compat // tracing rootCmd.PersistentFlags().String("honeycomb-api-key", "", "If specified, configures opentelemetry libraries to submit traces to honeycomb") cobra.CheckErr(viper.BindEnv("honeycomb-api-key", "STDLIB_HONEYCOMB_API_KEY", "HONEYCOMB_API_KEY")) // fallback to global config @@ -162,15 +152,10 @@ func init() { cobra.CheckErr(viper.BindEnv("json-log", "STDLIB_SOURCE_JSON_LOG", "JSON_LOG")) // fallback to global config // Bind these to viper - err := viper.BindPFlags(rootCmd.PersistentFlags()) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - }).Fatal("Could not bind flags to viper") - } + cobra.CheckErr(viper.BindPFlags(rootCmd.PersistentFlags())) // Run this before we do anything to set up the loglevel - rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + rootCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { if lvl, err := log.ParseLevel(logLevel); err == nil { log.SetLevel(lvl) } else { @@ -183,25 +168,29 @@ func init() { log.AddHook(TerminationLogHook{}) // Bind flags that haven't been set to the values from viper of we have them + var bindErr error cmd.PersistentFlags().VisitAll(func(f *pflag.Flag) { // Bind the flag to viper only if it has a non-empty default if f.DefValue != "" || f.Changed { - err = viper.BindPFlag(f.Name, f) - if err != nil { - log.WithFields(log.Fields{ - "error": err, - }).Fatal("Could not bind flag to viper") + if err := viper.BindPFlag(f.Name, f); err != nil { + bindErr = err } } }) + if bindErr != nil { + log.WithError(bindErr).Error("Could not bind flag to viper") + return fmt.Errorf("could not bind flag to viper: %w", bindErr) + } if viper.GetBool("json-log") { logging.ConfigureLogrusJSON(log.StandardLogger()) } if err := tracing.InitTracerWithUpstreams("stdlib-source", viper.GetString("honeycomb-api-key"), viper.GetString("sentry-dsn")); err != nil { - log.Fatal(err) + log.WithError(err).Error("could not init tracer") + return fmt.Errorf("could not init tracer: %w", err) } + return nil } // shut down tracing at the end of the process @@ -234,8 +223,9 @@ func (t TerminationLogHook) Levels() []log.Level { } func (t TerminationLogHook) Fire(e *log.Entry) error { - tLog, err := os.OpenFile("/dev/termination-log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - + // shutdown tracing first to ensure all spans are flushed + tracing.ShutdownTracer(context.Background()) + tLog, err := os.OpenFile("/dev/termination-log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) if err != nil { return err } diff --git a/tfutils/plan_mapper.go b/tfutils/plan_mapper.go index 10b6e281..f0b2311c 100644 --- a/tfutils/plan_mapper.go +++ b/tfutils/plan_mapper.go @@ -33,6 +33,8 @@ func (m MapStatus) String() string { return "not enough info" case MapStatusUnsupported: return "unsupported" + case MapStatusPendingCreation: + return "pending creation" default: return "unknown" } @@ -42,6 +44,7 @@ 
const ( MapStatusSuccess MapStatus = iota MapStatusNotEnoughInfo MapStatusUnsupported + MapStatusPendingCreation ) const KnownAfterApply = `(known after apply)` @@ -80,6 +83,10 @@ func (r *PlanMappingResult) NumUnsupported() int { return r.numStatus(MapStatusUnsupported) } +func (r *PlanMappingResult) NumPendingCreation() int { + return r.numStatus(MapStatusPendingCreation) +} + func (r *PlanMappingResult) NumTotal() int { return len(r.Results) } @@ -250,7 +257,7 @@ func MappedItemDiffsFromPlan(ctx context.Context, planJson []byte, fileName stri // Attach failed mappings to the span for _, result := range results.Results { switch result.Status { - case MapStatusUnsupported, MapStatusNotEnoughInfo: + case MapStatusUnsupported, MapStatusNotEnoughInfo, MapStatusPendingCreation: span.AddEvent("UnmappedResource", trace.WithAttributes( attribute.String("ovm.climap.status", result.Status.String()), attribute.String("ovm.climap.message", result.Message), @@ -274,14 +281,16 @@ func mapResourceToQuery(itemDiff *sdp.ItemDiff, terraformResource *Resource, map attemptedMappings := make([]string, 0) if len(mappings) == 0 { + mappingStatus := sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_UNSUPPORTED return PlannedChangeMapResult{ TerraformName: terraformResource.Address, TerraformType: terraformResource.Type, Status: MapStatusUnsupported, Message: "unsupported", MappedItemDiff: &sdp.MappedItemDiff{ - Item: itemDiff, - MappingQuery: nil, // unmapped item has no mapping query + Item: itemDiff, + MappingQuery: nil, // unmapped item has no mapping query + MappingStatus: &mappingStatus, }, } } @@ -312,14 +321,16 @@ func mapResourceToQuery(itemDiff *sdp.ItemDiff, terraformResource *Resource, map itemDiff.After.Type = mapping.OvermindType } + mappingStatus := sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_SUCCESS return PlannedChangeMapResult{ TerraformName: terraformResource.Address, TerraformType: terraformResource.Type, Status: MapStatusSuccess, Message: "mapped", MappedItemDiff: &sdp.MappedItemDiff{ - Item: itemDiff, - MappingQuery: newQuery, + Item: itemDiff, + MappingQuery: newQuery, + MappingStatus: &mappingStatus, }, } } @@ -331,14 +342,36 @@ func mapResourceToQuery(itemDiff *sdp.ItemDiff, terraformResource *Resource, map // If we get to this point, we haven't found a mapping message := fmt.Sprintf("missing mapping attribute: %v", strings.Join(attemptedMappings, ", ")) + + // Check if this is a newly created resource - these don't exist yet so missing + // attributes are expected, not an error + if itemDiff.GetStatus() == sdp.ItemDiffStatus_ITEM_DIFF_STATUS_CREATED { + mappingStatus := sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_PENDING_CREATION + return PlannedChangeMapResult{ + TerraformName: terraformResource.Address, + TerraformType: terraformResource.Type, + Status: MapStatusPendingCreation, + Message: "pending creation", + MappedItemDiff: &sdp.MappedItemDiff{ + Item: itemDiff, + MappingQuery: nil, // unmapped item has no mapping query + MappingStatus: &mappingStatus, + // No MappingError - this is expected, not an error + }, + } + } + + // For other statuses (REPLACED, UPDATED, DELETED), missing attributes are a real error + mappingStatus := sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_ERROR return PlannedChangeMapResult{ TerraformName: terraformResource.Address, TerraformType: terraformResource.Type, Status: MapStatusNotEnoughInfo, Message: message, MappedItemDiff: &sdp.MappedItemDiff{ - Item: itemDiff, - MappingQuery: nil, // unmapped item has no mapping 
query + Item: itemDiff, + MappingQuery: nil, // unmapped item has no mapping query + MappingStatus: &mappingStatus, MappingError: &sdp.QueryError{ ErrorType: sdp.QueryError_OTHER, ErrorString: message, diff --git a/tfutils/plan_mapper_test.go b/tfutils/plan_mapper_test.go index c4715acd..2e93aa20 100644 --- a/tfutils/plan_mapper_test.go +++ b/tfutils/plan_mapper_test.go @@ -23,6 +23,117 @@ func TestWithStateFile(t *testing.T) { } } +func TestMapResourceToQuery_PendingCreation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + itemDiffStatus sdp.ItemDiffStatus + hasMappings bool + expectedMapStatus MapStatus + expectedMappingStatus sdp.MappedItemMappingStatus + expectMappingError bool + }{ + { + name: "CREATED with missing attributes - pending creation", + itemDiffStatus: sdp.ItemDiffStatus_ITEM_DIFF_STATUS_CREATED, + hasMappings: true, + expectedMapStatus: MapStatusPendingCreation, + expectedMappingStatus: sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_PENDING_CREATION, + expectMappingError: false, + }, + { + name: "UPDATED with missing attributes - error", + itemDiffStatus: sdp.ItemDiffStatus_ITEM_DIFF_STATUS_UPDATED, + hasMappings: true, + expectedMapStatus: MapStatusNotEnoughInfo, + expectedMappingStatus: sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_ERROR, + expectMappingError: true, + }, + { + name: "DELETED with missing attributes - error", + itemDiffStatus: sdp.ItemDiffStatus_ITEM_DIFF_STATUS_DELETED, + hasMappings: true, + expectedMapStatus: MapStatusNotEnoughInfo, + expectedMappingStatus: sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_ERROR, + expectMappingError: true, + }, + { + name: "REPLACED with missing attributes - error", + itemDiffStatus: sdp.ItemDiffStatus_ITEM_DIFF_STATUS_REPLACED, + hasMappings: true, + expectedMapStatus: MapStatusNotEnoughInfo, + expectedMappingStatus: sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_ERROR, + expectMappingError: true, + }, + { + name: "No mappings - unsupported", + itemDiffStatus: sdp.ItemDiffStatus_ITEM_DIFF_STATUS_CREATED, + hasMappings: false, + expectedMapStatus: MapStatusUnsupported, + expectedMappingStatus: sdp.MappedItemMappingStatus_MAPPED_ITEM_MAPPING_STATUS_UNSUPPORTED, + expectMappingError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create an itemDiff with the specified status + itemDiff := &sdp.ItemDiff{ + Status: tt.itemDiffStatus, + } + + // Create a terraform resource without the mapping attribute (simulating missing id/arn) + terraformResource := &Resource{ + Address: "test_resource.example", + Type: "test_resource", + AttributeValues: AttributeValues{ + // No "id" field - simulating missing mapping attribute + "name": "test-name", + }, + } + + // Setup mappings - empty if testing unsupported, otherwise include one + var mappings []TfMapData + if tt.hasMappings { + mappings = []TfMapData{ + { + OvermindType: "test-type", + Method: sdp.QueryMethod_GET, + QueryField: "id", // This field doesn't exist in AttributeValues + }, + } + } + + // Call the function + result := mapResourceToQuery(itemDiff, terraformResource, mappings) + + // Verify the MapStatus + if result.Status != tt.expectedMapStatus { + t.Errorf("Expected MapStatus %v, got %v", tt.expectedMapStatus, result.Status) + } + + // Verify the MappingStatus + if result.MappedItemDiff.GetMappingStatus() != tt.expectedMappingStatus { + t.Errorf("Expected MappingStatus %v, got %v", tt.expectedMappingStatus, result.MappedItemDiff.GetMappingStatus()) + } + + // Verify 
MappingError presence + if tt.expectMappingError && result.MappedItemDiff.GetMappingError() == nil { + t.Error("Expected MappingError to be set, but it was nil") + } + if !tt.expectMappingError && result.MappedItemDiff.GetMappingError() != nil { + t.Errorf("Expected MappingError to be nil, but got: %v", result.MappedItemDiff.GetMappingError()) + } + + // Verify MappingQuery is nil (no query should be created when mapping fails) + if result.MappedItemDiff.GetMappingQuery() != nil { + t.Errorf("Expected MappingQuery to be nil, but got: %v", result.MappedItemDiff.GetMappingQuery()) + } + }) + } +} + func TestExtractProviderNameFromConfigKey(t *testing.T) { tests := []struct { ConfigKey string @@ -317,6 +428,9 @@ func TestPlanMappingResultNumFuncs(t *testing.T) { { Status: MapStatusUnsupported, }, + { + Status: MapStatusPendingCreation, + }, }, } @@ -331,6 +445,16 @@ func TestPlanMappingResultNumFuncs(t *testing.T) { if result.NumUnsupported() != 1 { t.Errorf("Expected 1 unsupported, got %v", result.NumUnsupported()) } + + if result.NumPendingCreation() != 1 { + t.Errorf("Expected 1 pending creation, got %v", result.NumPendingCreation()) + } + + // Sum of individual counts should equal NumTotal + sum := result.NumSuccess() + result.NumNotEnoughInfo() + result.NumUnsupported() + result.NumPendingCreation() + if sum != result.NumTotal() { + t.Errorf("Sum of status counts (%v) should equal NumTotal (%v)", sum, result.NumTotal()) + } } func TestInterpolateScope(t *testing.T) {
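A minimal sketch of consuming the new pending-creation counters on PlanMappingResult, assuming the repository's tfutils import path and that Results holds PlannedChangeMapResult values as the tests above suggest (illustrative only, not part of this change):

package main

import (
	"fmt"

	"github.com/overmindtech/cli/tfutils"
)

// summarise reports how many planned changes fell into each mapping bucket,
// including the new pending-creation category for resources that do not exist yet.
func summarise(r *tfutils.PlanMappingResult) string {
	return fmt.Sprintf("%d mapped, %d unsupported, %d not enough info, %d pending creation (total %d)",
		r.NumSuccess(), r.NumUnsupported(), r.NumNotEnoughInfo(), r.NumPendingCreation(), r.NumTotal())
}

func main() {
	r := &tfutils.PlanMappingResult{
		Results: []tfutils.PlannedChangeMapResult{
			{Status: tfutils.MapStatusSuccess},
			{Status: tfutils.MapStatusPendingCreation},
		},
	}
	// Prints: 1 mapped, 0 unsupported, 0 not enough info, 1 pending creation (total 2)
	fmt.Println(summarise(r))
}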