Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions sei-db/common/iterators/domain_iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package iterators

import (
"fmt"

dbm "github.com/tendermint/tm-db"
)

var _ dbm.Iterator = (*domainIterator)(nil)

// domainIterator wraps a parent iterator and overrides Domain() to report a
// caller-supplied [start, end) range. It is useful when the parent is built
// over a physical/translated keyspace (so its own Domain() reflects physical
// bounds) but callers expect the logical bounds they requested, as required by
// the dbm.Iterator contract. All other methods are inherited from the parent.
type domainIterator struct {
dbm.Iterator
start []byte
end []byte
}

// NewDomainIterator returns an iterator that behaves exactly like parent except
// that Domain() reports [start, end). The parent must be non-nil.
func NewDomainIterator(parent dbm.Iterator, start, end []byte) (dbm.Iterator, error) {
if parent == nil {
return nil, fmt.Errorf("nil parent iterator")
}
return &domainIterator{Iterator: parent, start: start, end: end}, nil
}

func (d *domainIterator) Domain() ([]byte, []byte) {
return d.start, d.end
}
63 changes: 63 additions & 0 deletions sei-db/common/iterators/domain_iterator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package iterators_test

import (
"testing"

"github.com/sei-protocol/sei-chain/sei-db/common/iterators"
"github.com/stretchr/testify/require"
)

func TestNewDomainIterator_NilParent(t *testing.T) {
it, err := iterators.NewDomainIterator(nil, []byte("a"), []byte("z"))
require.Error(t, err)
require.Nil(t, it)
}

func TestNewDomainIterator_OverridesDomain(t *testing.T) {
data := map[string][]byte{
"b": []byte("vb"),
"a": []byte("va"),
"c": []byte("vc"),
}
parent, err := iterators.NewMapIterator(nil, nil, true, iterators.BytesSerializer, data)
require.NoError(t, err)

// Sanity check: the parent reports nil bounds before wrapping.
pStart, pEnd := parent.Domain()
require.Nil(t, pStart)
require.Nil(t, pEnd)

start, end := []byte("a"), []byte("d")
it, err := iterators.NewDomainIterator(parent, start, end)
require.NoError(t, err)
defer it.Close()

gotStart, gotEnd := it.Domain()
require.Equal(t, start, gotStart)
require.Equal(t, end, gotEnd)
}

func TestNewDomainIterator_DelegatesIteration(t *testing.T) {
data := map[string][]byte{
"a": []byte("va"),
"b": []byte("vb"),
"c": []byte("vc"),
}
parent, err := iterators.NewMapIterator(nil, nil, true, iterators.BytesSerializer, data)
require.NoError(t, err)

it, err := iterators.NewDomainIterator(parent, []byte("a"), []byte("d"))
require.NoError(t, err)
defer it.Close()

var got [][2][]byte
for ; it.Valid(); it.Next() {
got = append(got, [2][]byte{it.Key(), it.Value()})
}
require.NoError(t, it.Error())
require.Equal(t, [][2][]byte{
{[]byte("a"), []byte("va")},
{[]byte("b"), []byte("vb")},
{[]byte("c"), []byte("vc")},
}, got)
}
160 changes: 160 additions & 0 deletions sei-db/common/iterators/map_iterator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package iterators

import (
"bytes"
"fmt"
"sort"

"github.com/sei-protocol/sei-chain/sei-db/common/utils"
dbm "github.com/tendermint/tm-db"
)

var _ dbm.Iterator = (*mapIterator[any])(nil)

// Iterates over a map of key/value pairs.
type mapIterator[T any] struct {
kvPairs []kvPair
currentIndex int
start []byte
end []byte
}

type kvPair struct {
key []byte
value []byte
}

// BytesSerializer is a pass-through serializer for map[string][]byte.
func BytesSerializer(v []byte) ([]byte, error) {
return v, nil
}

// NewMapIterator returns an iterator over the union of maps in lexicographic order
// (or reverse lex order when ascending is false). start is inclusive; end is
// exclusive. nil start or end means unbounded on that side. Duplicate keys across
// maps are rejected. Values are serialized with serializer before iteration.
func NewMapIterator[T any](
start []byte,
end []byte,
ascending bool,
serializer func(T) ([]byte, error),
maps ...map[string]T,
) (dbm.Iterator, error) {
if serializer == nil {
return nil, fmt.Errorf("nil serializer")
}
pairs, err := buildMapPairs(start, end, ascending, serializer, maps...)
if err != nil {
return nil, err
}
return &mapIterator[T]{
kvPairs: pairs,
start: start,
end: end,
}, nil
}

func buildMapPairs[T any](
start, end []byte,
ascending bool,
serializer func(T) ([]byte, error),
maps ...map[string]T,
) ([]kvPair, error) {
if start != nil && end != nil && bytes.Compare(start, end) > 0 {
return nil, nil
}

total := 0
for _, data := range maps {
total += len(data)
}
if total == 0 {
return nil, nil
}

seen := make(map[string]struct{}, total)
pairs := make([]kvPair, 0, total)
for _, data := range maps {
if data == nil {
continue
}
for k, v := range data {
if _, dup := seen[k]; dup {
return nil, fmt.Errorf("duplicate key %q", k)
}
seen[k] = struct{}{}

key := []byte(k)
if !keyInRange(key, start, end) {
continue
}

serialized, err := serializer(v)
if err != nil {
return nil, fmt.Errorf("serialize key %q: %w", k, err)
}
pairs = append(pairs, kvPair{
key: utils.Clone(key),
value: utils.Clone(serialized),
})
}
}

sort.Slice(pairs, func(i, j int) bool {
cmp := bytes.Compare(pairs[i].key, pairs[j].key)
if ascending {
return cmp < 0
}
return cmp > 0
})
return pairs, nil
}

func keyInRange(key, start, end []byte) bool {
if start != nil && bytes.Compare(key, start) < 0 {
return false
}
if end != nil && bytes.Compare(key, end) >= 0 {
return false
}
return true
}

func (m *mapIterator[T]) Close() error {
m.kvPairs = nil
m.currentIndex = 0
return nil
}

func (m *mapIterator[T]) Domain() ([]byte, []byte) {
return m.start, m.end
}

func (m *mapIterator[T]) Error() error {
return nil
}

func (m *mapIterator[T]) Key() []byte {
if !m.Valid() {
return nil
}
return m.kvPairs[m.currentIndex].key
}

func (m *mapIterator[T]) Next() {
if !m.Valid() {
return
}
m.currentIndex++
}

func (m *mapIterator[T]) Valid() bool {
return m.currentIndex >= 0 && m.currentIndex < len(m.kvPairs)
}

func (m *mapIterator[T]) Value() []byte {
if !m.Valid() {
return nil
}
return m.kvPairs[m.currentIndex].value
}
Loading
Loading