From 67aaf007d17b603827a6b1d6ac94d9ccd855f2ff Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 00:25:31 +0100 Subject: [PATCH 1/2] Add limit iterator and context plumbing for --limit flag --- libs/cmdio/limit.go | 50 +++++++++++++ libs/cmdio/limit_test.go | 146 ++++++++++++++++++++++++++++++++++++++ libs/cmdio/render.go | 3 + libs/cmdio/render_test.go | 30 ++++++++ 4 files changed, 229 insertions(+) create mode 100644 libs/cmdio/limit.go create mode 100644 libs/cmdio/limit_test.go diff --git a/libs/cmdio/limit.go b/libs/cmdio/limit.go new file mode 100644 index 0000000000..d6ef653a45 --- /dev/null +++ b/libs/cmdio/limit.go @@ -0,0 +1,50 @@ +package cmdio + +import ( + "context" + + "github.com/databricks/databricks-sdk-go/listing" +) + +type limitKey struct{} + +// WithLimit stores the limit in the context. +func WithLimit(ctx context.Context, n int) context.Context { + return context.WithValue(ctx, limitKey{}, n) +} + +// GetLimit retrieves the limit from context. Returns 0 if not set. +func GetLimit(ctx context.Context) int { + v, ok := ctx.Value(limitKey{}).(int) + if !ok { + return 0 + } + return v +} + +type limitIterator[T any] struct { + inner listing.Iterator[T] + remaining int +} + +func (l *limitIterator[T]) HasNext(ctx context.Context) bool { + return l.remaining > 0 && l.inner.HasNext(ctx) +} + +func (l *limitIterator[T]) Next(ctx context.Context) (T, error) { + v, err := l.inner.Next(ctx) + if err != nil { + return v, err + } + l.remaining-- + return v, nil +} + +// ApplyLimit wraps a listing.Iterator to yield at most the limit from context. +// It returns the iterator unchanged if the limit is not positive. +func ApplyLimit[T any](ctx context.Context, i listing.Iterator[T]) listing.Iterator[T] { + if limit := GetLimit(ctx); limit > 0 { + return &limitIterator[T]{inner: i, remaining: limit} + } + return i +} diff --git a/libs/cmdio/limit_test.go b/libs/cmdio/limit_test.go new file mode 100644 index 0000000000..33ec7e6401 --- /dev/null +++ b/libs/cmdio/limit_test.go @@ -0,0 +1,146 @@ +package cmdio_test + +import ( + "context" + "errors" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/listing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type sliceIterator[T any] struct { + items []T +} + +func (s *sliceIterator[T]) HasNext(_ context.Context) bool { + return len(s.items) > 0 +} + +func (s *sliceIterator[T]) Next(_ context.Context) (T, error) { + if len(s.items) == 0 { + var zero T + return zero, errors.New("no more items") + } + item := s.items[0] + s.items = s.items[1:] + return item, nil +} + +func drain[T any](ctx context.Context, iter listing.Iterator[T]) ([]T, error) { + var result []T + for iter.HasNext(ctx) { + v, err := iter.Next(ctx) + if err != nil { + return result, err + } + result = append(result, v) + } + return result, nil +} + +type errorIterator[T any] struct { + items []T + failAt int + callCount int +} + +func (e *errorIterator[T]) HasNext(_ context.Context) bool { + return e.callCount <= e.failAt && e.callCount < len(e.items) +} + +func (e *errorIterator[T]) Next(_ context.Context) (T, error) { + idx := e.callCount + e.callCount++ + if idx == e.failAt { + var zero T + return zero, errors.New("fetch error") + } + return e.items[idx], nil +} + +func TestWithLimitRoundTrip(t *testing.T) { + ctx := cmdio.WithLimit(t.Context(), 42) + assert.Equal(t, 42, cmdio.GetLimit(ctx)) +} + +func TestGetLimitReturnsZeroWhenNotSet(t *testing.T) { + assert.Equal(t, 0, cmdio.GetLimit(t.Context())) +} + +func TestApplyLimit(t *testing.T) { + tests := []struct { + name string + limit int + setLimit bool + items []int + want []int + }{ + { + name: "caps results", + limit: 5, + setLimit: true, + items: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + want: []int{1, 2, 3, 4, 5}, + }, + { + name: "no-op when unset", + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + { + name: "greater than total", + limit: 10, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + { + name: "one", + limit: 1, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1}, + }, + { + name: "zero", + limit: 0, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + { + name: "negative", + limit: -1, + setLimit: true, + items: []int{1, 2, 3}, + want: []int{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + if tt.setLimit { + ctx = cmdio.WithLimit(ctx, tt.limit) + } + + iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: tt.items}) + + result, err := drain(t.Context(), iter) + require.NoError(t, err) + assert.Equal(t, tt.want, result) + }) + } +} + +func TestApplyLimitPreservesErrors(t *testing.T) { + ctx := cmdio.WithLimit(t.Context(), 5) + iter := cmdio.ApplyLimit(ctx, &errorIterator[int]{items: []int{1, 2, 3}, failAt: 2}) + + result, err := drain(t.Context(), iter) + assert.ErrorContains(t, err, "fetch error") + assert.Equal(t, []int{1, 2}, result) +} diff --git a/libs/cmdio/render.go b/libs/cmdio/render.go index c344c3d028..2bd52cc64d 100644 --- a/libs/cmdio/render.go +++ b/libs/cmdio/render.go @@ -264,6 +264,7 @@ func Render(ctx context.Context, v any) error { } func RenderIterator[T any](ctx context.Context, i listing.Iterator[T]) error { + i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template) } @@ -277,11 +278,13 @@ func RenderWithTemplate(ctx context.Context, v any, headerTemplate, template str } func RenderIteratorWithTemplate[T any](ctx context.Context, i listing.Iterator[T], headerTemplate, template string) error { + i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, headerTemplate, template) } func RenderIteratorJson[T any](ctx context.Context, i listing.Iterator[T]) error { + i = ApplyLimit(ctx, i) c := fromContext(ctx) return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template) } diff --git a/libs/cmdio/render_test.go b/libs/cmdio/render_test.go index be41f80c38..67440b6b2d 100644 --- a/libs/cmdio/render_test.go +++ b/libs/cmdio/render_test.go @@ -167,6 +167,36 @@ var testCases = []testCase{ }, } +func TestRenderIteratorWithLimit(t *testing.T) { + output := &bytes.Buffer{} + ctx := t.Context() + cmdIO := NewIO(ctx, flags.OutputText, nil, output, output, + "id\tname", + "{{range .}}{{.WorkspaceId}}\t{{.WorkspaceName}}\n{{end}}") + ctx = InContext(ctx, cmdIO) + ctx = WithLimit(ctx, 3) + + err := RenderIterator(ctx, makeIterator(10)) + assert.NoError(t, err) + assert.Equal(t, "id name\n"+makeBigOutput(3), output.String()) +} + +func TestRenderIteratorWithLimitJSON(t *testing.T) { + output := &bytes.Buffer{} + ctx := t.Context() + cmdIO := NewIO(ctx, flags.OutputJSON, nil, output, output, "", "") + ctx = InContext(ctx, cmdIO) + ctx = WithLimit(ctx, 2) + + err := RenderIterator(ctx, makeIterator(10)) + assert.NoError(t, err) + + var items []provisioning.Workspace + err = json.Unmarshal(output.Bytes(), &items) + assert.NoError(t, err) + assert.Len(t, items, 2) +} + func TestRender(t *testing.T) { for _, c := range testCases { t.Run(c.name, func(t *testing.T) { From 8d90102f94775f6e79d13e23f5aa0f6b35204555 Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 13 Mar 2026 15:39:54 +0100 Subject: [PATCH 2/2] Add defensive guard in limitIterator.Next() for exhausted limit Return listing.ErrNoMoreItems when Next() is called with remaining <= 0, so the limit is enforced even if the caller skips HasNext(). --- libs/cmdio/limit.go | 4 ++++ libs/cmdio/limit_test.go | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/libs/cmdio/limit.go b/libs/cmdio/limit.go index d6ef653a45..8bcf84151b 100644 --- a/libs/cmdio/limit.go +++ b/libs/cmdio/limit.go @@ -32,6 +32,10 @@ func (l *limitIterator[T]) HasNext(ctx context.Context) bool { } func (l *limitIterator[T]) Next(ctx context.Context) (T, error) { + if l.remaining <= 0 { + var zero T + return zero, listing.ErrNoMoreItems + } v, err := l.inner.Next(ctx) if err != nil { return v, err diff --git a/libs/cmdio/limit_test.go b/libs/cmdio/limit_test.go index 33ec7e6401..2165a17da2 100644 --- a/libs/cmdio/limit_test.go +++ b/libs/cmdio/limit_test.go @@ -144,3 +144,21 @@ func TestApplyLimitPreservesErrors(t *testing.T) { assert.ErrorContains(t, err, "fetch error") assert.Equal(t, []int{1, 2}, result) } + +func TestLimitIteratorNextWithoutHasNextReturnsError(t *testing.T) { + ctx := cmdio.WithLimit(t.Context(), 2) + iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: []int{1, 2, 3, 4, 5}}) + + // Drain the allowed items. + v1, err := iter.Next(t.Context()) + require.NoError(t, err) + assert.Equal(t, 1, v1) + + v2, err := iter.Next(t.Context()) + require.NoError(t, err) + assert.Equal(t, 2, v2) + + // Calling Next() again without HasNext() must return ErrNoMoreItems. + _, err = iter.Next(t.Context()) + assert.ErrorIs(t, err, listing.ErrNoMoreItems) +}