Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions libs/cmdio/limit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package cmdio

import (
"context"

"github.com/databricks/databricks-sdk-go/listing"
)

type limitKey struct{}

// WithLimit stores the limit in the context.
func WithLimit(ctx context.Context, n int) context.Context {
return context.WithValue(ctx, limitKey{}, n)
}

// GetLimit retrieves the limit from context. Returns 0 if not set.
func GetLimit(ctx context.Context) int {
v, ok := ctx.Value(limitKey{}).(int)
if !ok {
return 0
}
return v
}

type limitIterator[T any] struct {
inner listing.Iterator[T]
remaining int
}

func (l *limitIterator[T]) HasNext(ctx context.Context) bool {
return l.remaining > 0 && l.inner.HasNext(ctx)
}

func (l *limitIterator[T]) Next(ctx context.Context) (T, error) {
if l.remaining <= 0 {
var zero T
return zero, listing.ErrNoMoreItems
}
v, err := l.inner.Next(ctx)
if err != nil {
return v, err
}
l.remaining--
return v, nil
}

// ApplyLimit wraps a listing.Iterator to yield at most the limit from context.
// It returns the iterator unchanged if the limit is not positive.
func ApplyLimit[T any](ctx context.Context, i listing.Iterator[T]) listing.Iterator[T] {
if limit := GetLimit(ctx); limit > 0 {
return &limitIterator[T]{inner: i, remaining: limit}
}
return i
}
164 changes: 164 additions & 0 deletions libs/cmdio/limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package cmdio_test

import (
"context"
"errors"
"testing"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/listing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type sliceIterator[T any] struct {
items []T
}

func (s *sliceIterator[T]) HasNext(_ context.Context) bool {
return len(s.items) > 0
}

func (s *sliceIterator[T]) Next(_ context.Context) (T, error) {
if len(s.items) == 0 {
var zero T
return zero, errors.New("no more items")
}
item := s.items[0]
s.items = s.items[1:]
return item, nil
}

func drain[T any](ctx context.Context, iter listing.Iterator[T]) ([]T, error) {
var result []T
for iter.HasNext(ctx) {
v, err := iter.Next(ctx)
if err != nil {
return result, err
}
result = append(result, v)
}
return result, nil
}

type errorIterator[T any] struct {
items []T
failAt int
callCount int
}

func (e *errorIterator[T]) HasNext(_ context.Context) bool {
return e.callCount <= e.failAt && e.callCount < len(e.items)
}

func (e *errorIterator[T]) Next(_ context.Context) (T, error) {
idx := e.callCount
e.callCount++
if idx == e.failAt {
var zero T
return zero, errors.New("fetch error")
}
return e.items[idx], nil
}

func TestWithLimitRoundTrip(t *testing.T) {
ctx := cmdio.WithLimit(t.Context(), 42)
assert.Equal(t, 42, cmdio.GetLimit(ctx))
}

func TestGetLimitReturnsZeroWhenNotSet(t *testing.T) {
assert.Equal(t, 0, cmdio.GetLimit(t.Context()))
}

func TestApplyLimit(t *testing.T) {
tests := []struct {
name string
limit int
setLimit bool
items []int
want []int
}{
{
name: "caps results",
limit: 5,
setLimit: true,
items: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
want: []int{1, 2, 3, 4, 5},
},
{
name: "no-op when unset",
items: []int{1, 2, 3},
want: []int{1, 2, 3},
},
{
name: "greater than total",
limit: 10,
setLimit: true,
items: []int{1, 2, 3},
want: []int{1, 2, 3},
},
{
name: "one",
limit: 1,
setLimit: true,
items: []int{1, 2, 3},
want: []int{1},
},
{
name: "zero",
limit: 0,
setLimit: true,
items: []int{1, 2, 3},
want: []int{1, 2, 3},
},
{
name: "negative",
limit: -1,
setLimit: true,
items: []int{1, 2, 3},
want: []int{1, 2, 3},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := t.Context()
if tt.setLimit {
ctx = cmdio.WithLimit(ctx, tt.limit)
}

iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: tt.items})

result, err := drain(t.Context(), iter)
require.NoError(t, err)
assert.Equal(t, tt.want, result)
})
}
}

func TestApplyLimitPreservesErrors(t *testing.T) {
ctx := cmdio.WithLimit(t.Context(), 5)
iter := cmdio.ApplyLimit(ctx, &errorIterator[int]{items: []int{1, 2, 3}, failAt: 2})

result, err := drain(t.Context(), iter)
assert.ErrorContains(t, err, "fetch error")
assert.Equal(t, []int{1, 2}, result)
}

func TestLimitIteratorNextWithoutHasNextReturnsError(t *testing.T) {
ctx := cmdio.WithLimit(t.Context(), 2)
iter := cmdio.ApplyLimit(ctx, &sliceIterator[int]{items: []int{1, 2, 3, 4, 5}})

// Drain the allowed items.
v1, err := iter.Next(t.Context())
require.NoError(t, err)
assert.Equal(t, 1, v1)

v2, err := iter.Next(t.Context())
require.NoError(t, err)
assert.Equal(t, 2, v2)

// Calling Next() again without HasNext() must return ErrNoMoreItems.
_, err = iter.Next(t.Context())
assert.ErrorIs(t, err, listing.ErrNoMoreItems)
}
3 changes: 3 additions & 0 deletions libs/cmdio/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ func Render(ctx context.Context, v any) error {
}

func RenderIterator[T any](ctx context.Context, i listing.Iterator[T]) error {
i = ApplyLimit(ctx, i)
c := fromContext(ctx)
return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template)
}
Expand All @@ -277,11 +278,13 @@ func RenderWithTemplate(ctx context.Context, v any, headerTemplate, template str
}

func RenderIteratorWithTemplate[T any](ctx context.Context, i listing.Iterator[T], headerTemplate, template string) error {
i = ApplyLimit(ctx, i)
c := fromContext(ctx)
return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, headerTemplate, template)
}

func RenderIteratorJson[T any](ctx context.Context, i listing.Iterator[T]) error {
i = ApplyLimit(ctx, i)
c := fromContext(ctx)
return renderWithTemplate(ctx, newIteratorRenderer(i), c.outputFormat, c.out, c.headerTemplate, c.template)
}
Expand Down
30 changes: 30 additions & 0 deletions libs/cmdio/render_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,36 @@ var testCases = []testCase{
},
}

func TestRenderIteratorWithLimit(t *testing.T) {
output := &bytes.Buffer{}
ctx := t.Context()
cmdIO := NewIO(ctx, flags.OutputText, nil, output, output,
"id\tname",
"{{range .}}{{.WorkspaceId}}\t{{.WorkspaceName}}\n{{end}}")
ctx = InContext(ctx, cmdIO)
ctx = WithLimit(ctx, 3)

err := RenderIterator(ctx, makeIterator(10))
assert.NoError(t, err)
assert.Equal(t, "id name\n"+makeBigOutput(3), output.String())
}

func TestRenderIteratorWithLimitJSON(t *testing.T) {
output := &bytes.Buffer{}
ctx := t.Context()
cmdIO := NewIO(ctx, flags.OutputJSON, nil, output, output, "", "")
ctx = InContext(ctx, cmdIO)
ctx = WithLimit(ctx, 2)

err := RenderIterator(ctx, makeIterator(10))
assert.NoError(t, err)

var items []provisioning.Workspace
err = json.Unmarshal(output.Bytes(), &items)
assert.NoError(t, err)
assert.Len(t, items, 2)
}

func TestRender(t *testing.T) {
for _, c := range testCases {
t.Run(c.name, func(t *testing.T) {
Expand Down
Loading