diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index be8c89b5db..4010586d47 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -175,6 +175,21 @@ jobs:
           done
           unset IFS # revert the internal field separator back to default

+      - name: Print node logs on failure
+        if: ${{ failure() }}
+        run: |
+          set -euo pipefail
+          for c in sei-node-0 sei-node-1 sei-node-2 sei-node-3; do
+            echo "==================== ${c} (docker logs tail) ===================="
+            docker logs --tail 200 "${c}" || true
+            echo "==================== ${c} (seid log file tail) ===================="
+            # Logs are accessible on host since build/generated is mounted in containers
+            NODE_ID=${c#sei-node-}
+            if [ -f "build/generated/logs/seid-${NODE_ID}.log" ]; then
+              tail -200 "build/generated/logs/seid-${NODE_ID}.log" || true
+            fi
+          done
+
       - name: Prepare log artifact name
         if: ${{ always() }}
         id: log_artifact_meta
diff --git a/go.mod b/go.mod
index 0f278af69b..cb9b9b5967 100644
--- a/go.mod
+++ b/go.mod
@@ -50,7 +50,7 @@ require (
     github.com/tendermint/tm-db v0.6.8-0.20220519162814-e24b96538a12
     github.com/tidwall/btree v1.6.0
     github.com/tidwall/gjson v1.10.2
-    github.com/tidwall/wal v1.1.7
+    github.com/tidwall/wal v1.2.1
     github.com/zbiljic/go-filelock v0.0.0-20170914061330-1dbf7103ab7d
     github.com/zeebo/blake3 v0.2.4
     go.opentelemetry.io/otel v1.38.0
diff --git a/go.sum b/go.sum
index 2d8b692fab..403fca7e7b 100644
--- a/go.sum
+++ b/go.sum
@@ -2083,8 +2083,8 @@ github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhso
 github.com/tidwall/sjson v1.1.4/go.mod h1:wXpKXu8CtDjKAZ+3DrKY5ROCorDFahq8l0tey/Lx1fg=
 github.com/tidwall/tinylru v1.1.0 h1:XY6IUfzVTU9rpwdhKUF6nQdChgCdGjkMfLzbWyiau6I=
 github.com/tidwall/tinylru v1.1.0/go.mod h1:3+bX+TJ2baOLMWTnlyNWHh4QMnFyARg2TLTQ6OFbzw8=
-github.com/tidwall/wal v1.1.7 h1:emc1TRjIVsdKKSnpwGBAcsAGg0767SvUk8+ygx7Bb+4=
-github.com/tidwall/wal v1.1.7/go.mod h1:r6lR1j27W9EPalgHiB7zLJDYu3mzW5BQP5KrzBpYY/E=
+github.com/tidwall/wal v1.2.1 h1:xQvwnRF3e+xBC4NvFvl1mPGJHU0aH5zNzlUKnKGIImA=
+github.com/tidwall/wal v1.2.1/go.mod h1:r6lR1j27W9EPalgHiB7zLJDYu3mzW5BQP5KrzBpYY/E=
 github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144 h1:kl4KhGNsJIbDHS9/4U9yQo1UcPQM0kOMJHn29EoH/Ro=
 github.com/timakin/bodyclose v0.0.0-20210704033933-f49887972144/go.mod h1:Qimiffbc6q9tBWlVV6x0P9sat/ao1xEkREYPPj9hphk=
 github.com/timonwong/loggercheck v0.9.3 h1:ecACo9fNiHxX4/Bc02rW2+kaJIAMAes7qJ7JKxt0EZI=
diff --git a/go.work.sum b/go.work.sum
index 87bb9eff79..bcfb1a7ea3 100644
--- a/go.work.sum
+++ b/go.work.sum
@@ -803,8 +803,6 @@ github.com/schollz/closestmatch v2.1.0+incompatible h1:Uel2GXEpJqOWBrlyI+oY9LTiy
 github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 h1:nn5Wsu0esKSJiIVhscUtVbo7ada43DJhG55ua/hjS5I=
 github.com/seccomp/libseccomp-golang v0.9.1 h1:NJjM5DNFOs0s3kYE1WUOr6G8V97sdt46rlXTMfXGWBo=
 github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM=
-github.com/sei-protocol/go-ethereum v1.15.7-sei-15 h1:cK2ZiNo9oWO4LeyRlZYZMILspPq0yoHIJdjLviAxtXE=
-github.com/sei-protocol/go-ethereum v1.15.7-sei-15/go.mod h1:+S9k+jFzlyVTNcYGvqFhzN/SFhI6vA+aOY4T5tLSPL0=
 github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
 github.com/shirou/gopsutil/v3 v3.22.9/go.mod h1:bBYl1kjgEJpWpxeHmLI+dVHWtyAwfcmSBLDsp2TNT8A=
 github.com/shirou/gopsutil/v3 v3.23.2 h1:PAWSuiAszn7IhPMBtXsbSCafej7PqUOvY6YywlQUExU=
diff --git a/sei-cosmos/go.mod b/sei-cosmos/go.mod
index a4ee0119ab..e0568be7e8 100644
--- a/sei-cosmos/go.mod
+++ b/sei-cosmos/go.mod
@@ -181,7 +181,7 @@ require (
     github.com/tidwall/match v1.1.1 // indirect
     github.com/tidwall/pretty v1.2.0 // indirect
     github.com/tidwall/tinylru v1.1.0 // indirect
-    github.com/tidwall/wal v1.1.7 // indirect
+    github.com/tidwall/wal v1.2.1 // indirect
     github.com/tklauser/go-sysconf v0.3.12 // indirect
     github.com/tklauser/numcpus v0.6.1 // indirect
     github.com/urfave/cli/v2 v2.27.5 // indirect
diff --git a/sei-cosmos/go.sum b/sei-cosmos/go.sum
index 81a38ba25b..4f62c799f8 100644
--- a/sei-cosmos/go.sum
+++ b/sei-cosmos/go.sum
@@ -1820,8 +1820,7 @@ github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhso
 github.com/tidwall/sjson v1.1.4/go.mod h1:wXpKXu8CtDjKAZ+3DrKY5ROCorDFahq8l0tey/Lx1fg=
 github.com/tidwall/tinylru v1.1.0 h1:XY6IUfzVTU9rpwdhKUF6nQdChgCdGjkMfLzbWyiau6I=
 github.com/tidwall/tinylru v1.1.0/go.mod h1:3+bX+TJ2baOLMWTnlyNWHh4QMnFyARg2TLTQ6OFbzw8=
-github.com/tidwall/wal v1.1.7 h1:emc1TRjIVsdKKSnpwGBAcsAGg0767SvUk8+ygx7Bb+4=
-github.com/tidwall/wal v1.1.7/go.mod h1:r6lR1j27W9EPalgHiB7zLJDYu3mzW5BQP5KrzBpYY/E=
+github.com/tidwall/wal v1.2.1 h1:xQvwnRF3e+xBC4NvFvl1mPGJHU0aH5zNzlUKnKGIImA=
 github.com/tklauser/go-sysconf v0.3.11/go.mod h1:GqXfhXY3kiPa0nAXPDIQIWzJbMCB7AmcWpGR8lSZfqI=
 github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
 github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
diff --git a/sei-db/changelog/changelog/changelog.go b/sei-db/changelog/changelog/changelog.go
deleted file mode 100644
index 7ac1951101..0000000000
--- a/sei-db/changelog/changelog/changelog.go
+++ /dev/null
@@ -1,264 +0,0 @@
-package changelog
-
-import (
-    "errors"
-    "fmt"
-    "os"
-    "path/filepath"
-    "time"
-
-    "github.com/sei-protocol/sei-chain/sei-db/changelog/types"
-    errorutils "github.com/sei-protocol/sei-chain/sei-db/common/errors"
-    "github.com/sei-protocol/sei-chain/sei-db/common/logger"
-    "github.com/sei-protocol/sei-chain/sei-db/proto"
-    "github.com/tidwall/wal"
-)
-
-var _ types.Stream[proto.ChangelogEntry] = (*Stream)(nil)
-
-type Stream struct {
-    dir          string
-    log          *wal.Log
-    config       Config
-    logger       logger.Logger
-    writeChannel chan *Message
-    errSignal    chan error
-    nextOffset   uint64
-    isClosed     bool
-}
-
-type Message struct {
-    Index uint64
-    Data  *proto.ChangelogEntry
-}
-
-type Config struct {
-    DisableFsync    bool
-    ZeroCopy        bool
-    WriteBufferSize int
-    KeepRecent      uint64
-    PruneInterval   time.Duration
-}
-
-// NewStream creates a new changelog stream that persist the changesets in the log
-func NewStream(logger logger.Logger, dir string, config Config) (*Stream, error) {
-    log, err := open(dir, &wal.Options{
-        NoSync: config.DisableFsync,
-        NoCopy: config.ZeroCopy,
-    })
-    if err != nil {
-        return nil, err
-    }
-    stream := &Stream{
-        dir:      dir,
-        log:      log,
-        config:   config,
-        logger:   logger,
-        isClosed: false,
-    }
-    // Finding the nextOffset to write
-    lastIndex, err := log.LastIndex()
-    if err != nil {
-        return nil, err
-    }
-    stream.nextOffset = lastIndex + 1
-    // Start the auto pruning goroutine
-    if config.KeepRecent > 0 {
-        go stream.StartPruning(config.KeepRecent, config.PruneInterval)
-    }
-    return stream, nil
-
-}
-
-// Write will write a new entry to the log at given index.
-// Whether the writes is in blocking or async manner depends on the buffer size.
-func (stream *Stream) Write(offset uint64, entry proto.ChangelogEntry) error {
-    channelBufferSize := stream.config.WriteBufferSize
-    if channelBufferSize > 0 {
-        if stream.writeChannel == nil {
-            stream.logger.Info(fmt.Sprintf("async write is enabled with buffer size %d", channelBufferSize))
-            stream.startWriteGoroutine()
-        }
-        // async write
-        stream.writeChannel <- &Message{Index: offset, Data: &entry}
-    } else {
-        // synchronous write
-        bz, err := entry.Marshal()
-        if err != nil {
-            return err
-        }
-        if err := stream.log.Write(offset, bz); err != nil {
-            return err
-        }
-    }
-    return nil
-}
-
-// WriteNextEntry will write a new entry to the last index of the log.
-// Whether the writes is in blocking or async manner depends on the buffer size.
-func (stream *Stream) WriteNextEntry(entry proto.ChangelogEntry) error {
-    nextOffset := stream.nextOffset
-    err := stream.Write(nextOffset, entry)
-    if err != nil {
-        return err
-    }
-    stream.nextOffset++
-    return nil
-}
-
-// startWriteGoroutine will start a goroutine to write entries to the log.
-// This should only be called on initialization if async write is enabled
-func (stream *Stream) startWriteGoroutine() {
-    stream.writeChannel = make(chan *Message, stream.config.WriteBufferSize)
-    stream.errSignal = make(chan error)
-    go func() {
-        batch := wal.Batch{}
-        defer close(stream.errSignal)
-        for {
-            entries := channelBatchRecv(stream.writeChannel)
-            if len(entries) == 0 {
-                // channel is closed
-                break
-            }
-
-            for _, entry := range entries {
-                bz, err := entry.Data.Marshal()
-                if err != nil {
-                    stream.errSignal <- err
-                    return
-                }
-                batch.Write(entry.Index, bz)
-            }
-
-            if err := stream.log.WriteBatch(&batch); err != nil {
-                stream.errSignal <- err
-                return
-            }
-            batch.Clear()
-        }
-    }()
-}
-
-// TruncateAfter will remove all entries that are after the provided `index`.
-// In other words the entry at `index` becomes the last entry in the log.
-func (stream *Stream) TruncateAfter(index uint64) error {
-    return stream.log.TruncateBack(index)
-}
-
-// TruncateBefore will remove all entries that are before the provided `index`.
-// In other words the entry at `index` becomes the first entry in the log.
-func (stream *Stream) TruncateBefore(index uint64) error {
-    return stream.log.TruncateFront(index)
-}
-
-// CheckError check if there's any failed async writes or not
-func (stream *Stream) CheckError() error {
-    select {
-    case err := <-stream.errSignal:
-        // async wal writing failed, we need to abort the state machine
-        return fmt.Errorf("async wal writing goroutine quit unexpectedly: %w", err)
-    default:
-    }
-    return nil
-}
-
-func (stream *Stream) FirstOffset() (index uint64, err error) {
-    return stream.log.FirstIndex()
-}
-
-// LastOffset returns the last written offset/index of the log
-func (stream *Stream) LastOffset() (index uint64, err error) {
-    return stream.log.LastIndex()
-}
-
-// ReadAt will read the log entry at the provided index
-func (stream *Stream) ReadAt(index uint64) (*proto.ChangelogEntry, error) {
-    var entry = &proto.ChangelogEntry{}
-    bz, err := stream.log.Read(index)
-    if err != nil {
-        return entry, fmt.Errorf("read log failed, %w", err)
-    }
-    if err := entry.Unmarshal(bz); err != nil {
-        return entry, fmt.Errorf("unmarshal rlog failed, %w", err)
-    }
-    return entry, nil
-}
-
-// Replay will read the replay log and process each log entry with the provided function
-func (stream *Stream) Replay(start uint64, end uint64, processFn func(index uint64, entry proto.ChangelogEntry) error) error {
-    for i := start; i <= end; i++ {
-        var entry proto.ChangelogEntry
-        bz, err := stream.log.Read(i)
-        if err != nil {
-            return fmt.Errorf("read log failed, %w", err)
-        }
-        if err := entry.Unmarshal(bz); err != nil {
-            return fmt.Errorf("unmarshal rlog failed, %w", err)
-        }
-        err = processFn(i, entry)
-        if err != nil {
-            return err
-        }
-    }
-    return nil
-}
-
-func (stream *Stream) StartPruning(keepRecent uint64, pruneInterval time.Duration) {
-    for !stream.isClosed {
-        lastIndex, _ := stream.log.LastIndex()
-        firstIndex, _ := stream.log.FirstIndex()
-        if lastIndex > keepRecent && (lastIndex-keepRecent) > firstIndex {
-            prunePos := lastIndex - keepRecent
-            err := stream.TruncateBefore(prunePos)
-            stream.logger.Error(fmt.Sprintf("failed to prune changelog till index %d", prunePos), "err", err)
-        }
-        time.Sleep(pruneInterval)
-    }
-}
-
-func (stream *Stream) Close() error {
-    if stream.writeChannel == nil {
-        return nil
-    }
-    close(stream.writeChannel)
-    err := <-stream.errSignal
-    stream.writeChannel = nil
-    stream.errSignal = nil
-    errClose := stream.log.Close()
-    stream.isClosed = true
-    return errorutils.Join(err, errClose)
-}
-
-// open opens the replay log, try to truncate the corrupted tail if there's any
-func open(dir string, opts *wal.Options) (*wal.Log, error) {
-    if opts == nil {
-        opts = wal.DefaultOptions
-    }
-    rlog, err := wal.Open(dir, opts)
-    if errors.Is(err, wal.ErrCorrupt) {
-        // try to truncate corrupted tail
-        var fis []os.DirEntry
-        fis, err = os.ReadDir(dir)
-        if err != nil {
-            return nil, fmt.Errorf("read wal dir fail: %w", err)
-        }
-        var lastSeg string
-        for _, fi := range fis {
-            if fi.IsDir() || len(fi.Name()) < 20 {
-                continue
-            }
-            lastSeg = fi.Name()
-        }
-
-        if len(lastSeg) == 0 {
-            return nil, err
-        }
-        if err = truncateCorruptedTail(filepath.Join(dir, lastSeg), opts.LogFormat); err != nil {
-            return nil, fmt.Errorf("truncate corrupted tail fail: %w", err)
-        }
-
-        // try again
-        return wal.Open(dir, opts)
-    }
-    return rlog, err
-}
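Review note: for readers who want the gist of what was removed here, the deleted `Stream` implemented batched async writes by draining its channel and flushing one `wal.Batch` per drain. Below is a minimal standalone sketch of that pattern against `github.com/tidwall/wal`; the `drain` helper stands in for the deleted `channelBatchRecv`, which is not shown in this diff, and the raw `[]byte` payloads stand in for marshaled `proto.ChangelogEntry` values.

```go
package main

import "github.com/tidwall/wal"

// drain blocks for at least one message, then opportunistically drains
// whatever else is already buffered, so one WriteBatch covers many entries.
// Returns nil once the channel is closed and empty.
func drain(ch chan []byte) [][]byte {
	first, ok := <-ch
	if !ok {
		return nil
	}
	out := [][]byte{first}
	for {
		select {
		case m, ok := <-ch:
			if !ok {
				return out
			}
			out = append(out, m)
		default:
			return out
		}
	}
}

// asyncWriter is the write goroutine: batch, flush, clear, repeat.
// Any write error is reported once on errc and the goroutine exits.
func asyncWriter(log *wal.Log, ch chan []byte, errc chan<- error, next uint64) {
	defer close(errc)
	var batch wal.Batch
	for {
		msgs := drain(ch)
		if msgs == nil {
			return // channel closed; clean shutdown
		}
		for _, m := range msgs {
			batch.Write(next, m) // WAL indices must stay contiguous
			next++
		}
		if err := log.WriteBatch(&batch); err != nil {
			errc <- err
			return
		}
		batch.Clear()
	}
}
```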
diff --git a/sei-db/changelog/changelog/changelog_test.go b/sei-db/changelog/changelog/changelog_test.go
deleted file mode 100644
index 860b39f4b9..0000000000
--- a/sei-db/changelog/changelog/changelog_test.go
+++ /dev/null
@@ -1,178 +0,0 @@
-package changelog
-
-import (
-    "fmt"
-    "os"
-    "path/filepath"
-    "testing"
-
-    "github.com/sei-protocol/sei-chain/sei-db/common/logger"
-    "github.com/sei-protocol/sei-chain/sei-db/proto"
-    iavl "github.com/sei-protocol/sei-chain/sei-iavl"
-    "github.com/stretchr/testify/require"
-    "github.com/tidwall/wal"
-)
-
-var (
-    ChangeSets = []iavl.ChangeSet{
-        {Pairs: MockKVPairs("hello", "world")},
-        {Pairs: MockKVPairs("hello1", "world1", "hello2", "world2")},
-        {Pairs: MockKVPairs("hello3", "world3")},
-    }
-)
-
-func TestOpenAndCorruptedTail(t *testing.T) {
-    opts := &wal.Options{
-        LogFormat: wal.JSON,
-    }
-    dir := t.TempDir()
-
-    testCases := []struct {
-        name      string
-        logs      []byte
-        lastIndex uint64
-    }{
-        {"failure-1", []byte("\n"), 0},
-        {"failure-2", []byte(`{}` + "\n"), 0},
-        {"failure-3", []byte(`{"index":"1"}` + "\n"), 0},
-        {"failure-4", []byte(`{"index":"1","data":"?"}`), 0},
-        {"failure-5", []byte(`{"index":1,"data":"?"}` + "\n" + `{"index":"1","data":"?"}`), 1},
-    }
-
-    for _, tc := range testCases {
-        t.Run(tc.name, func(t *testing.T) {
-            os.WriteFile(filepath.Join(dir, "00000000000000000001"), tc.logs, 0o600)
-
-            _, err := wal.Open(dir, opts)
-            require.Equal(t, wal.ErrCorrupt, err)
-
-            log, err := open(dir, opts)
-            require.NoError(t, err)
-
-            lastIndex, err := log.LastIndex()
-            require.NoError(t, err)
-            require.Equal(t, tc.lastIndex, lastIndex)
-        })
-    }
-}
-
-func TestReplay(t *testing.T) {
-    changelog := prepareTestData(t)
-    var total = 0
-    err := changelog.Replay(1, 2, func(index uint64, entry proto.ChangelogEntry) error {
-        total++
-        switch index {
-        case 1:
-            require.Equal(t, "test", entry.Changesets[0].Name)
-            require.Equal(t, []byte("hello"), entry.Changesets[0].Changeset.Pairs[0].Key)
-            require.Equal(t, []byte("world"), entry.Changesets[0].Changeset.Pairs[0].Value)
-        case 2:
-            require.Equal(t, []byte("hello1"), entry.Changesets[0].Changeset.Pairs[0].Key)
-            require.Equal(t, []byte("world1"), entry.Changesets[0].Changeset.Pairs[0].Value)
-            require.Equal(t, []byte("hello2"), entry.Changesets[0].Changeset.Pairs[1].Key)
-            require.Equal(t, []byte("world2"), entry.Changesets[0].Changeset.Pairs[1].Value)
-        default:
-            require.Fail(t, fmt.Sprintf("unexpected index %d", index))
-        }
-        return nil
-    })
-    require.NoError(t, err)
-    require.Equal(t, 2, total)
-    err = changelog.Close()
-    require.NoError(t, err)
-}
-
-func TestRandomRead(t *testing.T) {
-    changelog := prepareTestData(t)
-    entry, err := changelog.ReadAt(2)
-    require.NoError(t, err)
-    require.Equal(t, []byte("hello1"), entry.Changesets[0].Changeset.Pairs[0].Key)
-    require.Equal(t, []byte("world1"), entry.Changesets[0].Changeset.Pairs[0].Value)
-    require.Equal(t, []byte("hello2"), entry.Changesets[0].Changeset.Pairs[1].Key)
-    require.Equal(t, []byte("world2"), entry.Changesets[0].Changeset.Pairs[1].Value)
-    entry, err = changelog.ReadAt(1)
-    require.NoError(t, err)
-    require.Equal(t, []byte("hello"), entry.Changesets[0].Changeset.Pairs[0].Key)
-    require.Equal(t, []byte("world"), entry.Changesets[0].Changeset.Pairs[0].Value)
-    entry, err = changelog.ReadAt(3)
-    require.NoError(t, err)
-    require.Equal(t, []byte("hello3"), entry.Changesets[0].Changeset.Pairs[0].Key)
-    require.Equal(t, []byte("world3"), entry.Changesets[0].Changeset.Pairs[0].Value)
-}
-
-func prepareTestData(t *testing.T) *Stream {
-    dir := t.TempDir()
-    changelog, err := NewStream(logger.NewNopLogger(), dir, Config{})
-    require.NoError(t, err)
-    writeTestData(changelog)
-    return changelog
-}
-
-func writeTestData(changelog *Stream) {
-    for i, changes := range ChangeSets {
-        cs := []*proto.NamedChangeSet{
-            {
-                Name:      "test",
-                Changeset: changes,
-            },
-        }
-        entry := &proto.ChangelogEntry{}
-        entry.Changesets = cs
-        _ = changelog.Write(uint64(i+1), *entry)
-    }
-}
-
-func TestSynchronousWrite(t *testing.T) {
-    changelog := prepareTestData(t)
-    lastIndex, err := changelog.LastOffset()
-    require.NoError(t, err)
-    require.Equal(t, uint64(3), lastIndex)
-
-}
-
-func TestAsyncWrite(t *testing.T) {
-    dir := t.TempDir()
-    changelog, err := NewStream(logger.NewNopLogger(), dir, Config{WriteBufferSize: 10})
-    require.NoError(t, err)
-    for i, changes := range ChangeSets {
-        cs := []*proto.NamedChangeSet{
-            {
-                Name:      "test",
-                Changeset: changes,
-            },
-        }
-        entry := &proto.ChangelogEntry{}
-        entry.Changesets = cs
-        err := changelog.Write(uint64(i+1), *entry)
-        require.NoError(t, err)
-    }
-    err = changelog.Close()
-    require.NoError(t, err)
-    changelog, err = NewStream(logger.NewNopLogger(), dir, Config{WriteBufferSize: 10})
-    require.NoError(t, err)
-    lastIndex, err := changelog.LastOffset()
-    require.NoError(t, err)
-    require.Equal(t, uint64(3), lastIndex)
-}
-
-func TestOpenWithNilOptions(t *testing.T) {
-    dir := t.TempDir()
-
-    // Test that open function handles nil options correctly
-    log, err := open(dir, nil)
-    require.NoError(t, err)
-    require.NotNil(t, log)
-
-    // Verify the log is functional by checking first and last index
-    firstIndex, err := log.FirstIndex()
-    require.NoError(t, err)
-    require.Equal(t, uint64(0), firstIndex)
-
-    lastIndex, err := log.LastIndex()
-    require.NoError(t, err)
-    require.Equal(t, uint64(0), lastIndex)
-
-    // Clean up
-    err = log.Close()
-    require.NoError(t, err)
-}
diff --git a/sei-db/changelog/changelog/subscriber.go b/sei-db/changelog/changelog/subscriber.go
deleted file mode 100644
index b8fe896303..0000000000
--- a/sei-db/changelog/changelog/subscriber.go
+++ /dev/null
@@ -1,86 +0,0 @@
-package changelog
-
-import (
-    "fmt"
-
-    "github.com/sei-protocol/sei-chain/sei-db/changelog/types"
-    "github.com/sei-protocol/sei-chain/sei-db/proto"
-)
-
-var _ types.Subscriber[proto.ChangelogEntry] = (*Subscriber)(nil)
-
-type Subscriber struct {
-    maxPendingSize   int
-    chPendingEntries chan proto.ChangelogEntry
-    errSignal        chan error
-    stopSignal       chan struct{}
-    processFn        func(entry proto.ChangelogEntry) error
-}
-
-func NewSubscriber(
-    maxPendingSize int,
-    processFn func(entry proto.ChangelogEntry) error,
-) *Subscriber {
-    subscriber := &Subscriber{
-        maxPendingSize: maxPendingSize,
-        processFn:      processFn,
-    }
-
-    return subscriber
-}
-
-func (s *Subscriber) Start() {
-    if s.maxPendingSize > 0 {
-        s.startAsyncProcessing()
-    }
-}
-
-func (s *Subscriber) ProcessEntry(entry proto.ChangelogEntry) error {
-    if s.maxPendingSize <= 0 {
-        return s.processFn(entry)
-    }
-    s.chPendingEntries <- entry
-    return s.CheckError()
-}
-
-func (s *Subscriber) startAsyncProcessing() {
-    if s.chPendingEntries == nil {
-        s.chPendingEntries = make(chan proto.ChangelogEntry, s.maxPendingSize)
-        s.errSignal = make(chan error)
-        go func() {
-            defer close(s.errSignal)
-            for {
-                select {
-                case entry := <-s.chPendingEntries:
-                    if err := s.processFn(entry); err != nil {
-                        s.errSignal <- err
-                    }
-                case <-s.stopSignal:
-                    return
-                }
-            }
-        }()
-    }
-}
-
-func (s *Subscriber) Close() error {
-    if s.chPendingEntries != nil {
-        return nil
-    }
-    s.stopSignal <- struct{}{}
-    close(s.chPendingEntries)
-    err := s.CheckError()
-    s.chPendingEntries = nil
-    s.errSignal = nil
-    return err
-}
-
-func (s *Subscriber) CheckError() error {
-    select {
-    case err := <-s.errSignal:
-        // async wal writing failed, we need to abort the state machine
-        return fmt.Errorf("subscriber failed unexpectedly: %w", err)
-    default:
-    }
-    return nil
-}
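One hazard worth flagging before this pattern gets resurrected elsewhere: the deleted `Subscriber.Close` guard is inverted (it returns early when `chPendingEntries` is non-nil, i.e. exactly when there is something to stop), and `stopSignal` is never initialized in `NewSubscriber`, so the send would block forever. A corrected sketch of the shutdown path follows; it mirrors the deleted type's fields but is illustrative only, not a sei-db API, and it assumes `stopSignal` is created when async processing starts.

```go
package changelog

// Minimal mirror of the deleted Subscriber's fields (sketch only).
// chPendingEntries stands in for the original chan proto.ChangelogEntry.
type Subscriber struct {
	chPendingEntries chan struct{}
	errSignal        chan error
	stopSignal       chan struct{} // assumed initialized in startAsyncProcessing
}

// CheckError drains one pending error non-blockingly, as the deleted code did.
func (s *Subscriber) CheckError() error {
	select {
	case err := <-s.errSignal:
		return err
	default:
		return nil
	}
}

// Close with the guard the deleted code inverted: bail out only when async
// mode was never started, then actually stop the goroutine and drain errors.
func (s *Subscriber) Close() error {
	if s.chPendingEntries == nil {
		return nil // async processing was never started
	}
	s.stopSignal <- struct{}{} // ask the processing goroutine to exit
	close(s.chPendingEntries)
	err := s.CheckError() // surface any error it reported before exiting
	s.chPendingEntries = nil
	s.errSignal = nil
	return err
}
```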
diff --git a/sei-db/db_engine/pebbledb/db_test.go b/sei-db/db_engine/pebbledb/db_test.go
index f29c61c75c..0a1ce5fdfe 100644
--- a/sei-db/db_engine/pebbledb/db_test.go
+++ b/sei-db/db_engine/pebbledb/db_test.go
@@ -6,6 +6,7 @@ import (

     "github.com/cockroachdb/pebble"
     "github.com/sei-protocol/sei-chain/sei-db/db_engine"
+    "github.com/stretchr/testify/require"
 )

 func TestDBGetSetDelete(t *testing.T) {
@@ -14,7 +15,7 @@ func TestDBGetSetDelete(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     key := []byte("k1")
     val := []byte("v1")
@@ -52,10 +53,10 @@ func TestBatchAtomicWrite(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     b := db.NewBatch()
-    defer func() { _ = b.Close() }()
+    t.Cleanup(func() { require.NoError(t, b.Close()) })

     if err := b.Set([]byte("a"), []byte("1")); err != nil {
         t.Fatalf("batch set: %v", err)
@@ -91,7 +92,7 @@ func TestIteratorBounds(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     // Keys: a, b, c
     for _, k := range []string{"a", "b", "c"} {
@@ -104,7 +105,7 @@ func TestIteratorBounds(t *testing.T) {
     if err != nil {
         t.Fatalf("NewIter: %v", err)
     }
-    defer func() { _ = itr.Close() }()
+    t.Cleanup(func() { require.NoError(t, itr.Close()) })

     var keys []string
     for ok := itr.First(); ok && itr.Valid(); ok = itr.Next() {
@@ -125,7 +126,7 @@ func TestIteratorPrev(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     // Keys: a, b, c
     for _, k := range []string{"a", "b", "c"} {
@@ -138,7 +139,7 @@ func TestIteratorPrev(t *testing.T) {
     if err != nil {
         t.Fatalf("NewIter: %v", err)
     }
-    defer func() { _ = itr.Close() }()
+    t.Cleanup(func() { require.NoError(t, itr.Close()) })

     if !itr.Last() || !itr.Valid() {
         t.Fatalf("expected Last() to position iterator")
@@ -190,7 +191,7 @@ func TestIteratorNextPrefixWithComparerSplit(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     for _, k := range []string{"a/1", "a/2", "a/3", "b/1"} {
         if err := db.Set([]byte(k), []byte("x"), db_engine.WriteOptions{Sync: false}); err != nil {
@@ -202,7 +203,7 @@ func TestIteratorNextPrefixWithComparerSplit(t *testing.T) {
     if err != nil {
         t.Fatalf("NewIter: %v", err)
     }
-    defer func() { _ = itr.Close() }()
+    t.Cleanup(func() { require.NoError(t, itr.Close()) })

     if !itr.SeekGE([]byte("a/")) || !itr.Valid() {
         t.Fatalf("expected SeekGE(a/) to be valid")
@@ -233,7 +234,7 @@ func TestErrNotFoundConsistency(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     // Test that Get on missing key returns ErrNotFound
     _, err = db.Get([]byte("missing-key"))
@@ -258,7 +259,7 @@ func TestGetReturnsCopy(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     key := []byte("k")
     val := []byte("v")
@@ -288,7 +289,7 @@ func TestBatchLenResetDelete(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     // First, set a key so we can delete it
     if err := db.Set([]byte("to-delete"), []byte("val"), db_engine.WriteOptions{Sync: false}); err != nil {
@@ -296,7 +297,7 @@ func TestBatchLenResetDelete(t *testing.T) {
     }

     b := db.NewBatch()
-    defer func() { _ = b.Close() }()
+    t.Cleanup(func() { require.NoError(t, b.Close()) })

     // Record initial batch len (Pebble batch always has a header, so may not be 0)
     initialLen := b.Len()
@@ -344,7 +345,7 @@ func TestIteratorSeekLTAndValue(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     // Insert keys: a, b, c with values
     for _, kv := range []struct{ k, v string }{
@@ -361,7 +362,7 @@ func TestIteratorSeekLTAndValue(t *testing.T) {
     if err != nil {
         t.Fatalf("NewIter: %v", err)
     }
-    defer func() { _ = itr.Close() }()
+    t.Cleanup(func() { require.NoError(t, itr.Close()) })

     // SeekLT("c") should position at "b"
     if !itr.SeekLT([]byte("c")) || !itr.Valid() {
@@ -381,7 +382,7 @@ func TestFlush(t *testing.T) {
     if err != nil {
         t.Fatalf("Open: %v", err)
     }
-    defer func() { _ = db.Close() }()
+    t.Cleanup(func() { require.NoError(t, db.Close()) })

     // Set some data
     if err := db.Set([]byte("flush-test"), []byte("val"), db_engine.WriteOptions{Sync: false}); err != nil {
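The repeated change in this file is worth naming: `defer func() { _ = x.Close() }()` swallows close errors, while `t.Cleanup(func() { require.NoError(t, x.Close()) })` runs after the test (including its subtests) completes, in LIFO order, and fails the test if `Close` returns an error. A self-contained illustration of the pattern, with a temp file standing in for the Pebble handles:

```go
package example_test

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestCleanupPattern(t *testing.T) {
	// Any io.Closer works here; a temp file stands in for the DB/iterator.
	f, err := os.Create(filepath.Join(t.TempDir(), "resource"))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	// t.Cleanup defers teardown until the test and its subtests finish.
	// Unlike the discarded-error defer, a failing Close now fails the test.
	t.Cleanup(func() { require.NoError(t, f.Close()) })
}
```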
diff --git a/sei-db/db_engine/pebbledb/mvcc/db.go b/sei-db/db_engine/pebbledb/mvcc/db.go
index 1f1b72c6cf..c8236022f3 100644
--- a/sei-db/db_engine/pebbledb/mvcc/db.go
+++ b/sei-db/db_engine/pebbledb/mvcc/db.go
@@ -15,16 +15,17 @@ import (
     "github.com/armon/go-metrics"
     "github.com/cockroachdb/pebble"
     "github.com/cockroachdb/pebble/bloom"
-    "github.com/sei-protocol/sei-chain/sei-db/changelog/changelog"
+    "go.opentelemetry.io/otel/attribute"
+    "go.opentelemetry.io/otel/metric"
+    "golang.org/x/exp/slices"
+
     errorutils "github.com/sei-protocol/sei-chain/sei-db/common/errors"
     "github.com/sei-protocol/sei-chain/sei-db/common/logger"
     "github.com/sei-protocol/sei-chain/sei-db/common/utils"
     "github.com/sei-protocol/sei-chain/sei-db/config"
     "github.com/sei-protocol/sei-chain/sei-db/proto"
     "github.com/sei-protocol/sei-chain/sei-db/state_db/ss/types"
-    "go.opentelemetry.io/otel/attribute"
-    "go.opentelemetry.io/otel/metric"
-    "golang.org/x/exp/slices"
+    "github.com/sei-protocol/sei-chain/sei-db/wal"
 )

 const (
@@ -65,7 +66,7 @@ type Database struct {
     storeKeyDirty sync.Map

     // Changelog used to support async write
-    streamHandler *changelog.Stream
+    streamHandler wal.ChangelogWAL

     // Pending changes to be written to the DB
     pendingChanges chan VersionedChangesets
@@ -150,16 +151,13 @@ func OpenDB(dataDir string, config config.StateStoreConfig) (*Database, error) {
         _ = db.Close()
         return nil, errors.New("KeepRecent must be non-negative")
     }
-    streamHandler, _ := changelog.NewStream(
-        logger.NewNopLogger(),
-        utils.GetChangelogPath(dataDir),
-        changelog.Config{
-            DisableFsync:  true,
-            ZeroCopy:      true,
-            KeepRecent:    uint64(config.KeepRecent),
-            PruneInterval: time.Duration(config.PruneIntervalSeconds) * time.Second,
-        },
-    )
+    streamHandler, err := wal.NewChangelogWAL(logger.NewNopLogger(), utils.GetChangelogPath(dataDir), wal.Config{
+        KeepRecent:    uint64(config.KeepRecent),
+        PruneInterval: time.Duration(config.PruneIntervalSeconds) * time.Second,
+    })
+    if err != nil {
+        return nil, err
+    }
     database.streamHandler = streamHandler
     database.asyncWriteWG.Add(1)
     go database.writeAsyncInBackground()
@@ -447,7 +445,7 @@ func (db *Database) ApplyChangesetAsync(version int64, changesets []*proto.Named
     }
     entry.Changesets = changesets
     entry.Upgrades = nil
-    err := db.streamHandler.WriteNextEntry(entry)
+    err := db.streamHandler.Write(entry)
     if err != nil {
         return err
     }
diff --git a/sei-db/db_engine/rocksdb/mvcc/db.go b/sei-db/db_engine/rocksdb/mvcc/db.go
index 580ca3d13d..53f1a67a79 100644
--- a/sei-db/db_engine/rocksdb/mvcc/db.go
+++ b/sei-db/db_engine/rocksdb/mvcc/db.go
@@ -13,7 +13,8 @@ import (
     "time"

     "github.com/linxGnu/grocksdb"
-    "github.com/sei-protocol/sei-chain/sei-db/changelog/changelog"
+    "golang.org/x/exp/slices"
+
     "github.com/sei-protocol/sei-chain/sei-db/common/errors"
     "github.com/sei-protocol/sei-chain/sei-db/common/logger"
     "github.com/sei-protocol/sei-chain/sei-db/common/utils"
@@ -21,7 +22,7 @@ import (
     "github.com/sei-protocol/sei-chain/sei-db/proto"
     "github.com/sei-protocol/sei-chain/sei-db/state_db/ss/types"
     "github.com/sei-protocol/sei-chain/sei-db/state_db/ss/util"
-    "golang.org/x/exp/slices"
+    "github.com/sei-protocol/sei-chain/sei-db/wal"
 )

 const (
@@ -65,7 +66,7 @@ type Database struct {
     asyncWriteWG sync.WaitGroup

     // Changelog used to support async write
-    streamHandler *changelog.Stream
+    streamHandler wal.ChangelogWAL

     // Pending changes to be written to the DB
     pendingChanges chan VersionedChangesets
@@ -112,16 +113,13 @@ func OpenDB(dataDir string, config config.StateStoreConfig) (*Database, error) {
     }
     database.latestVersion.Store(latestVersion)

-    streamHandler, _ := changelog.NewStream(
-        logger.NewNopLogger(),
-        utils.GetChangelogPath(dataDir),
-        changelog.Config{
-            DisableFsync:  true,
-            ZeroCopy:      true,
-            KeepRecent:    uint64(config.KeepRecent),
-            PruneInterval: time.Duration(config.PruneIntervalSeconds) * time.Second,
-        },
-    )
+    streamHandler, err := wal.NewChangelogWAL(logger.NewNopLogger(), utils.GetChangelogPath(dataDir), wal.Config{
+        KeepRecent:    uint64(config.KeepRecent),
+        PruneInterval: time.Duration(config.PruneIntervalSeconds) * time.Second,
+    })
+    if err != nil {
+        return nil, err
+    }
     database.streamHandler = streamHandler
     go database.writeAsyncInBackground()
@@ -263,7 +261,7 @@ func (db *Database) ApplyChangesetAsync(version int64, changesets []*proto.Named
     }
     entry.Changesets = changesets
     entry.Upgrades = nil
-    err := db.streamHandler.WriteNextEntry(entry)
+    err := db.streamHandler.Write(entry)
     if err != nil {
         return err
     }
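The new `sei-db/wal` package itself is not part of this diff. Pieced together from the call sites above and in memiavl below, its surface looks roughly like the following; treat this as a reconstruction under assumptions, not the actual source (in particular, `NewChangelogWAL` here is a stub, and the package also appears to expose `GetLastIndex` and `LogPath` helpers used by `GetLatestVersion`):

```go
// Reconstructed from call sites in this diff; the real sei-db/wal may differ.
package wal

import (
	"time"

	"github.com/sei-protocol/sei-chain/sei-db/common/logger"
	"github.com/sei-protocol/sei-chain/sei-db/proto"
)

type Config struct {
	WriteBufferSize int           // >0 appears to enable buffered async writes
	KeepRecent      uint64        // entries retained by background pruning
	PruneInterval   time.Duration // pruning cadence
}

// ChangelogWAL is the interface pebbledb/mvcc, rocksdb/mvcc and memiavl
// program against. Note Write takes no offset: callers no longer manage
// indices themselves (the old Stream exposed both Write(offset, entry)
// and WriteNextEntry(entry)).
type ChangelogWAL interface {
	Write(entry proto.ChangelogEntry) error
	FirstOffset() (uint64, error)
	LastOffset() (uint64, error)
	TruncateAfter(index uint64) error
	TruncateBefore(index uint64) error
	Replay(start, end uint64, fn func(index uint64, entry proto.ChangelogEntry) error) error
	Close() error
}

// NewChangelogWAL opens (or creates) the changelog WAL under dir.
func NewChangelogWAL(log logger.Logger, dir string, cfg Config) (ChangelogWAL, error) {
	panic("reconstruction only; see sei-db/wal for the real implementation")
}
```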
"github.com/sei-protocol/sei-chain/sei-db/common/errors" "github.com/sei-protocol/sei-chain/sei-db/common/logger" "github.com/sei-protocol/sei-chain/sei-db/common/utils" "github.com/sei-protocol/sei-chain/sei-db/proto" + "github.com/sei-protocol/sei-chain/sei-db/wal" iavl "github.com/sei-protocol/sei-chain/sei-iavl" ) @@ -55,6 +53,13 @@ type DB struct { readOnly bool opts Options + // streamHandler is the changelog WAL owned by MemIAVL. + // It is opened during OpenDB (if present / allowed) and closed in DB.Close(). + streamHandler wal.ChangelogWAL + // pendingLogEntry accumulates changes (changesets + upgrades) to be written + // into the changelog WAL on the next Commit(). + pendingLogEntry proto.ChangelogEntry + // result channel of snapshot rewrite goroutine snapshotRewriteChan chan snapshotResult // context cancel function to cancel the snapshot rewrite goroutine @@ -72,11 +77,10 @@ type DB struct { // make sure only one snapshot rewrite is running pruneSnapshotLock sync.Mutex - // the changelog stream persists all the changesets - streamHandler types.Stream[proto.ChangelogEntry] - - // pending change, will be written into rlog file in next Commit call - pendingLogEntry proto.ChangelogEntry + // walIndexDelta is the difference: version - walIndex for any entry. + // Since both WAL indices and versions are strictly contiguous, this delta is constant. + // Computed once when opening the DB from the first WAL entry. + walIndexDelta int64 // The assumptions to concurrency: // - The methods on DB are protected by a mutex @@ -199,20 +203,32 @@ func OpenDB(logger logger.Logger, targetVersion int64, opts Options) (database * tree.snapshot.leavesMap.PrepareForRandomRead() } - // Create rlog manager and open the rlog file - streamHandler, err := changelog.NewStream(logger, utils.GetChangelogPath(opts.Dir), changelog.Config{ - DisableFsync: true, - ZeroCopy: true, + // MemIAVL owns changelog lifecycle: always open the WAL here. + // Even in read-only mode we may need WAL replay to reconstruct non-snapshot versions. + streamHandler, err := wal.NewChangelogWAL(logger, utils.GetChangelogPath(opts.Dir), wal.Config{ WriteBufferSize: opts.AsyncCommitBuffer, }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to open changelog WAL: %w", err) + } + + // Compute WAL index delta (only needed once per DB open) + var walIndexDelta int64 + var walHasEntries bool + walIndexDelta, walHasEntries, err = computeWALIndexDelta(streamHandler) + if err != nil { + return nil, fmt.Errorf("failed to compute WAL index delta: %w", err) + } + // If WAL is empty, set delta so first WAL entry aligns with NextVersion(). 
+ if !walHasEntries { + walIndexDelta = mtree.WorkingCommitInfo().Version - 1 } - if targetVersion == 0 || targetVersion > mtree.Version() { + // Replay WAL to catch up to target version (if WAL has entries) + if walHasEntries && (targetVersion == 0 || targetVersion > mtree.Version()) { logger.Info("Start catching up and replaying the MemIAVL changelog file") - if err := mtree.Catchup(context.Background(), streamHandler, targetVersion); err != nil { - return nil, errorutils.Join(err, streamHandler.Close()) + if err := mtree.Catchup(context.Background(), streamHandler, walIndexDelta, targetVersion); err != nil { + return nil, err } logger.Info(fmt.Sprintf("Finished the replay and caught up to version %d", targetVersion)) } @@ -231,11 +247,16 @@ func OpenDB(logger logger.Logger, targetVersion int64, opts Options) (database * } } - // truncate the rlog file - logger.Info("truncate rlog after version: %d", targetVersion) - truncateIndex := utils.VersionToIndex(targetVersion, mtree.initialVersion.Load()) - if err := streamHandler.TruncateAfter(truncateIndex); err != nil { - return nil, fmt.Errorf("fail to truncate rlog file: %w", err) + // truncate the rlog file (if WAL is provided and has entries) + if walHasEntries { + logger.Info("truncate rlog after version: %d", targetVersion) + // Use O(1) conversion: walIndex = version - delta + truncateIndex := targetVersion - walIndexDelta + if truncateIndex > 0 { + if err := streamHandler.TruncateAfter(uint64(truncateIndex)); err != nil { + return nil, fmt.Errorf("fail to truncate rlog file: %w", err) + } + } } // prune snapshots that's larger than the target version @@ -254,6 +275,7 @@ func OpenDB(logger logger.Logger, targetVersion int64, opts Options) (database * return nil, fmt.Errorf("fail to prune snapshots: %w", err) } } + // create worker pool. recv tasks to write snapshot workerPool := pond.New(opts.SnapshotWriterLimit, opts.SnapshotWriterLimit*10) @@ -268,6 +290,7 @@ func OpenDB(logger logger.Logger, targetVersion int64, opts Options) (database * dir: opts.Dir, fileLock: fileLock, readOnly: opts.ReadOnly, + walIndexDelta: walIndexDelta, streamHandler: streamHandler, snapshotKeepRecent: opts.SnapshotKeepRecent, snapshotInterval: opts.SnapshotInterval, @@ -277,20 +300,34 @@ func OpenDB(logger logger.Logger, targetVersion int64, opts Options) (database * opts: opts, } - if !db.readOnly && db.Version() == 0 && len(opts.InitialStores) > 0 { - // do the initial upgrade with the `opts.InitialStores` + // Apply initial stores on a fresh DB (version 0) so they get persisted to WAL. + // This creates the trees and populates pendingLogEntry, which will be written + // to WAL on the first Commit(). + // ApplyUpgrades is idempotent (skips existing trees), so this is safe. + if !opts.ReadOnly && db.Version() == 0 && len(opts.InitialStores) > 0 { var upgrades []*proto.TreeNameUpgrade for _, name := range opts.InitialStores { upgrades = append(upgrades, &proto.TreeNameUpgrade{Name: name}) } if err := db.ApplyUpgrades(upgrades); err != nil { - return nil, errorutils.Join(err, db.Close()) + return nil, fmt.Errorf("failed to apply initial stores: %w", err) } } return db, nil } +// GetWAL returns the WAL handler for changelog operations. +func (db *DB) GetWAL() wal.ChangelogWAL { + return db.streamHandler +} + +// GetWALIndexDelta returns the precomputed delta between version and WAL index. 
+// This allows O(1) conversion: version = walIndex + delta, walIndex = version - delta +func (db *DB) GetWALIndexDelta() int64 { + return db.walIndexDelta +} + func removeTmpDirs(rootDir string) error { entries, err := os.ReadDir(rootDir) if err != nil { @@ -337,8 +374,7 @@ func (db *DB) SetInitialVersion(initialVersion int64) error { return initEmptyDB(db.dir, db.initialVersion.Load()) } -// ApplyUpgrades wraps MultiTree.ApplyUpgrades, it also appends the upgrades in a pending log, -// which will be persisted to the rlog in next Commit call. +// ApplyUpgrades wraps MultiTree.ApplyUpgrades to add a lock. func (db *DB) ApplyUpgrades(upgrades []*proto.TreeNameUpgrade) error { db.mtx.Lock() defer db.mtx.Unlock() @@ -347,16 +383,13 @@ func (db *DB) ApplyUpgrades(upgrades []*proto.TreeNameUpgrade) error { return errReadOnly } - if err := db.MultiTree.ApplyUpgrades(upgrades); err != nil { - return err + if len(upgrades) > 0 { + db.pendingLogEntry.Upgrades = append(db.pendingLogEntry.Upgrades, upgrades...) } - - db.pendingLogEntry.Upgrades = append(db.pendingLogEntry.Upgrades, upgrades...) - return nil + return db.MultiTree.ApplyUpgrades(upgrades) } -// ApplyChangeSets wraps MultiTree.ApplyChangeSets, it also appends the changesets in the pending log, -// which will be persisted to the rlog in next Commit call. +// ApplyChangeSets wraps MultiTree.ApplyChangeSets to add a lock. func (db *DB) ApplyChangeSets(changeSets []*proto.NamedChangeSet) (_err error) { if len(changeSets) == 0 { return nil @@ -378,16 +411,12 @@ func (db *DB) ApplyChangeSets(changeSets []*proto.NamedChangeSet) (_err error) { return errReadOnly } - if len(db.pendingLogEntry.Changesets) > 0 { - return errors.New("don't support multiple ApplyChangeSets calls in the same version") - } + // Overwrite pending changesets for this commit; callers typically provide them once per block. db.pendingLogEntry.Changesets = changeSets - return db.MultiTree.ApplyChangeSets(changeSets) } -// ApplyChangeSet wraps MultiTree.ApplyChangeSet, it also appends the changesets in the pending log, -// which will be persisted to the rlog in next Commit call. +// ApplyChangeSet wraps MultiTree.ApplyChangeSet to add a lock. func (db *DB) ApplyChangeSet(name string, changeSet iavl.ChangeSet) error { if len(changeSet.Pairs) == 0 { return nil @@ -400,41 +429,28 @@ func (db *DB) ApplyChangeSet(name string, changeSet iavl.ChangeSet) error { return errReadOnly } - for _, cs := range db.pendingLogEntry.Changesets { - if cs.Name == name { - return errors.New("don't support multiple ApplyChangeSet calls with the same name in the same version") - } - } - db.pendingLogEntry.Changesets = append(db.pendingLogEntry.Changesets, &proto.NamedChangeSet{ Name: name, Changeset: changeSet, }) - sort.SliceStable(db.pendingLogEntry.Changesets, func(i, j int) bool { - return db.pendingLogEntry.Changesets[i].Name < db.pendingLogEntry.Changesets[j].Name - }) - return db.MultiTree.ApplyChangeSet(name, changeSet) } // checkAsyncTasks checks the status of background tasks non-blocking-ly and process the result func (db *DB) checkAsyncTasks() error { - return errorutils.Join( - db.streamHandler.CheckError(), - db.checkBackgroundSnapshotRewrite(), - ) + return db.checkBackgroundSnapshotRewrite() } -// CommittedVersion returns the latest version written in rlog file, or snapshot version if rlog is empty. +// CommittedVersion returns the current version of the MultiTree. 
func (db *DB) CommittedVersion() (int64, error) { - lastIndex, err := db.streamHandler.LastOffset() + lastOffset, err := db.GetWAL().LastOffset() if err != nil { return 0, err } - if lastIndex == 0 { + if lastOffset == 0 { return db.SnapshotVersion(), nil } - return utils.IndexToVersion(lastIndex, db.initialVersion.Load()), nil + return db.walIndexToVersion(lastOffset), nil } // checkBackgroundSnapshotRewrite check the result of background snapshot rewrite, cleans up the old snapshots and switches to a new multitree @@ -455,8 +471,8 @@ func (db *DB) checkBackgroundSnapshotRewrite() error { return fmt.Errorf("background snapshot rewriting failed: %w", result.err) } - // wait for potential pending rlog writings to finish, to make sure we catch up to latest state. - // in real world, block execution should be slower than rlog writing, so this should not block for long. + // wait for potential pending writes to finish, to make sure we catch up to latest state. + // in real world, block execution should be slower than tree updates, so this should not block for long. for { committedVersion, err := db.CommittedVersion() if err != nil { @@ -469,8 +485,10 @@ func (db *DB) checkBackgroundSnapshotRewrite() error { } // catchup the remaining entries in rlog - if err := result.mtree.Catchup(context.Background(), db.streamHandler, 0); err != nil { - return fmt.Errorf("catchup failed: %w", err) + if wal := db.GetWAL(); wal != nil { + if err := result.mtree.Catchup(context.Background(), wal, db.walIndexDelta, 0); err != nil { + return fmt.Errorf("catchup failed: %w", err) + } } // do the switch @@ -489,7 +507,8 @@ func (db *DB) checkBackgroundSnapshotRewrite() error { return nil } -// pruneSnapshot prune the old snapshots +// pruneSnapshots prunes old snapshots, keeping only snapshotKeepRecent recent ones. +// Note: WAL truncation is now handled by CommitStore after each commit. func (db *DB) pruneSnapshots() { // wait until last prune finish db.pruneSnapshotLock.Lock() @@ -525,19 +544,54 @@ func (db *DB) pruneSnapshots() { db.logger.Error("fail to prune snapshots", "err", err) return } +} - // truncate Rlog until the earliest remaining snapshot - earliestVersion, err := GetEarliestVersion(db.dir) +// computeWALIndexDelta computes the constant delta between version and WAL index. +// Since both are strictly contiguous, we only need to read one entry. +// Returns (delta, hasEntries, error). hasEntries is false if WAL is empty. +func computeWALIndexDelta(stream wal.ChangelogWAL) (int64, bool, error) { + firstIndex, err := stream.FirstOffset() if err != nil { - db.logger.Error("failed to find first snapshot", "err", err) + return 0, false, err + } + if firstIndex == 0 { + return 0, false, nil // empty WAL } - if err := db.streamHandler.TruncateBefore(utils.VersionToIndex(earliestVersion+1, db.initialVersion.Load())); err != nil { - db.logger.Error("failed to truncate rlog", "err", err, "version", earliestVersion+1) + // Read just the first entry to compute delta + var firstVersion int64 + err = stream.Replay(firstIndex, firstIndex, func(index uint64, entry proto.ChangelogEntry) error { + firstVersion = entry.Version + return nil + }) + if err != nil { + return 0, false, err + } + + // delta = version - index, so for any entry: version = index + delta + // #nosec G115 -- WAL indices are always much smaller than MaxInt64 in practice + return firstVersion - int64(firstIndex), true, nil +} + +// versionToWALIndex converts a version to its corresponding WAL index using the precomputed delta. 
+// Returns 0 if the version would result in an invalid (negative or zero) index. +func (db *DB) versionToWALIndex(version int64) uint64 { + index := version - db.walIndexDelta + if index <= 0 { + return 0 } + // #nosec G115 -- index is guaranteed positive by the check above + return uint64(index) +} + +// walIndexToVersion converts a WAL index to its corresponding version using the precomputed delta. +func (db *DB) walIndexToVersion(index uint64) int64 { + // #nosec G115 -- WAL indices are always much smaller than MaxInt64 in practice + return int64(index) + db.walIndexDelta } -// Commit wraps SaveVersion to bump the version and writes the pending changes into log files to persist on disk +// Commit wraps SaveVersion to bump the version and finalize the tree state. +// MemIAVL owns the changelog: it writes the pending changelog entry before committing the tree. func (db *DB) Commit() (version int64, _err error) { startTime := time.Now() defer func() { @@ -557,20 +611,29 @@ func (db *DB) Commit() (version int64, _err error) { return 0, errReadOnly } + // Commit the in-memory tree state FIRST. + // MemIAVL is purely in-memory; SaveVersion() doesn't persist anything. + // The changelog WAL is our persistence layer. v, err := db.MultiTree.SaveVersion(true) if err != nil { return 0, err } - // write to changelog - if db.streamHandler != nil { - db.pendingLogEntry.Version = v - err := db.streamHandler.Write(utils.VersionToIndex(v, db.initialVersion.Load()), db.pendingLogEntry) - if err != nil { - return 0, err + // Write to WAL AFTER successful SaveVersion. + // Rationale: If SaveVersion fails but we already wrote to WAL, we'd have + // a WAL entry for a version that was never committed. On replay, this would + // corrupt state. By writing WAL after SaveVersion succeeds, we ensure WAL + // only contains valid committed versions. If WAL write fails after SaveVersion, + // we lose this version on crash (rollback to prior state), but remain consistent. + // + // Note: Write() automatically checks for any previous async write errors. + if wal := db.GetWAL(); wal != nil { + entry := db.pendingLogEntry + entry.Version = v + if err := wal.Write(entry); err != nil { + return 0, fmt.Errorf("failed to write changelog WAL: %w", err) } } - db.pendingLogEntry = proto.ChangelogEntry{} if err := db.checkAsyncTasks(); err != nil { @@ -579,10 +642,41 @@ func (db *DB) Commit() (version int64, _err error) { // Rewrite tree snapshot if applicable db.rewriteIfApplicable(v) + db.tryTruncateWAL() return v, nil } +// tryTruncateWAL best-effort truncates old WAL entries that are older than the earliest snapshot. 
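The conversion helpers above are plain offset arithmetic, which is easy to sanity-check with concrete numbers. A standalone sketch with assumed values (not sei-db code): if InitialVersion is 100, the first commit is version 100 at WAL index 1, so delta = 100 - 1 = 99 and both conversions are O(1):

```go
package main

import "fmt"

// Mirror of the delta-based conversions used above, with an assumed delta.
const walIndexDelta = 99 // firstVersion(100) - firstIndex(1)

func versionToWALIndex(version int64) uint64 {
	index := version - walIndexDelta
	if index <= 0 {
		return 0 // version predates the first WAL entry
	}
	return uint64(index)
}

func walIndexToVersion(index uint64) int64 {
	return int64(index) + walIndexDelta
}

func main() {
	fmt.Println(versionToWALIndex(100)) // 1
	fmt.Println(versionToWALIndex(104)) // 5
	fmt.Println(walIndexToVersion(5))   // 104
	fmt.Println(versionToWALIndex(99))  // 0: invalid, before the first entry
}
```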
+func (db *DB) tryTruncateWAL() { + if db.streamHandler == nil { + return + } + firstWALIndex, err := db.streamHandler.FirstOffset() + if err != nil || firstWALIndex == 0 { + return + } + earliestSnapshotVersion, err := GetEarliestVersion(db.dir) + if err != nil { + return + } + if firstWALIndex > uint64(math.MaxInt64) { + db.logger.Error("WAL first offset overflows int64; skipping truncation", "firstWALIndex", firstWALIndex) + return + } + walEarliestVersion := db.walIndexToVersion(firstWALIndex) + if walEarliestVersion >= earliestSnapshotVersion { + return + } + truncateIndex := db.versionToWALIndex(earliestSnapshotVersion) + if truncateIndex == 0 || truncateIndex <= firstWALIndex { + return + } + if err := db.streamHandler.TruncateBefore(truncateIndex); err != nil { + db.logger.Error("failed to truncate changelog WAL", "err", err, "truncateIndex", truncateIndex) + } +} + func (db *DB) Copy() *DB { db.mtx.Lock() defer db.mtx.Unlock() @@ -697,10 +791,8 @@ func (db *DB) reload() error { } func (db *DB) reloadMultiTree(mtree *MultiTree) error { - // catch-up the pending changes - if err := mtree.apply(db.pendingLogEntry); err != nil { - return err - } + // The caller is responsible for ensuring mtree is caught up to the latest state + // (either via Catchup from WAL or by loading a current snapshot). return db.ReplaceWith(mtree) } @@ -811,13 +903,15 @@ func (db *DB) rewriteSnapshotBackground() error { cloned.logger.Info("loaded multitree after snapshot", "elapsed", time.Since(loadStart).Seconds()) // do a best effort catch-up, will do another final catch-up in main thread. - catchupStart := time.Now() - if err := mtree.Catchup(ctx, db.streamHandler, 0); err != nil { - cloned.logger.Error("failed to catchup after snapshot", "error", err) - ch <- snapshotResult{err: err} - return + if wal := db.GetWAL(); wal != nil { + catchupStart := time.Now() + if err := mtree.Catchup(ctx, wal, db.walIndexDelta, 0); err != nil { + cloned.logger.Error("failed to catchup after snapshot", "error", err) + ch <- snapshotResult{err: err} + return + } + cloned.logger.Info("finished best-effort catchup", "version", cloned.Version(), "latest", mtree.Version(), "elapsed", time.Since(catchupStart).Seconds()) } - cloned.logger.Info("finished best-effort catchup", "version", cloned.Version(), "latest", mtree.Version(), "elapsed", time.Since(catchupStart).Seconds()) ch <- snapshotResult{mtree: mtree} totalElapsed := time.Since(startTime).Seconds() @@ -839,8 +933,7 @@ func (db *DB) Close() error { db.pruneSnapshotLock.Lock() defer db.pruneSnapshotLock.Unlock() - // Close rewrite channel first - must wait for background goroutine before closing streamHandler - // because the goroutine may still be using streamHandler + // Close rewrite channel first - must wait for background goroutine before closing WAL db.logger.Info("Closing rewrite channel...") if db.snapshotRewriteChan != nil { db.snapshotRewriteCancelFunc() @@ -861,11 +954,9 @@ func (db *DB) Close() error { db.snapshotRewriteCancelFunc = nil } - // Close stream handler after background goroutine has finished - db.logger.Info("Closing stream handler...") + // Close WAL after snapshot rewrite goroutine has fully exited. 
if db.streamHandler != nil { - err := db.streamHandler.Close() - errs = append(errs, err) + errs = append(errs, db.streamHandler.Close()) db.streamHandler = nil } @@ -1138,7 +1229,7 @@ func GetLatestVersion(dir string) (int64, error) { } return 0, err } - lastIndex, err := changelog.GetLastIndex(changelog.LogPath(dir)) + lastIndex, err := wal.GetLastIndex(wal.LogPath(dir)) if err != nil { return 0, err } diff --git a/sei-db/state_db/sc/memiavl/db_rewrite_test.go b/sei-db/state_db/sc/memiavl/db_rewrite_test.go index e2ff7fd06e..062bbc8e80 100644 --- a/sei-db/state_db/sc/memiavl/db_rewrite_test.go +++ b/sei-db/state_db/sc/memiavl/db_rewrite_test.go @@ -81,7 +81,7 @@ func TestLoadMultiTreeWithPrefetchDisabled(t *testing.T) { db2, err := OpenDB(logger.NewNopLogger(), 0, opts) require.NoError(t, err) - defer db2.Close() + t.Cleanup(func() { require.NoError(t, db2.Close()) }) // Verify data is accessible tree := db2.TreeByName("test") diff --git a/sei-db/state_db/sc/memiavl/db_test.go b/sei-db/state_db/sc/memiavl/db_test.go index 07e477ff58..a0d822fd03 100644 --- a/sei-db/state_db/sc/memiavl/db_test.go +++ b/sei-db/state_db/sc/memiavl/db_test.go @@ -11,12 +11,13 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/sei-protocol/sei-chain/sei-db/common/errors" "github.com/sei-protocol/sei-chain/sei-db/common/logger" "github.com/sei-protocol/sei-chain/sei-db/common/utils" "github.com/sei-protocol/sei-chain/sei-db/proto" iavl "github.com/sei-protocol/sei-chain/sei-iavl" - "github.com/stretchr/testify/require" ) func TestRewriteSnapshot(t *testing.T) { @@ -26,7 +27,7 @@ func TestRewriteSnapshot(t *testing.T) { InitialStores: []string{"test"}, }) require.NoError(t, err) - defer db.Close() // Ensure DB cleanup + t.Cleanup(func() { require.NoError(t, db.Close()) }) // Ensure DB cleanup for i, changes := range ChangeSets { cs := []*proto.NamedChangeSet{ @@ -50,7 +51,6 @@ func TestRewriteSnapshot(t *testing.T) { func TestRemoveSnapshotDir(t *testing.T) { dbDir := t.TempDir() - defer os.RemoveAll(dbDir) snapshotDir := filepath.Join(dbDir, snapshotName(0)) tmpDir := snapshotDir + "-tmp" @@ -100,7 +100,7 @@ func TestRewriteSnapshotBackground(t *testing.T) { SnapshotKeepRecent: 0, // only a single snapshot is kept }) require.NoError(t, err) - defer db.Close() // Ensure DB cleanup and goroutine termination + t.Cleanup(func() { require.NoError(t, db.Close()) }) // Ensure DB cleanup and goroutine termination // spin up goroutine to keep querying the tree stopCh := make(chan struct{}) @@ -159,7 +159,7 @@ func TestRewriteSnapshotBackground(t *testing.T) { entries, err := os.ReadDir(db.dir) require.NoError(t, err) - // three files: snapshot, current link, rlog, LOCK + // snapshot, current link, LOCK, changelog WAL dir require.Equal(t, 4, len(entries)) // stopCh is closed by defer above } @@ -229,10 +229,12 @@ func TestSnapshotTriggerOnIntervalDiff(t *testing.T) { func TestRlog(t *testing.T) { dir := t.TempDir() + initialStores := []string{"test", "delete"} + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ Dir: dir, CreateIfMissing: true, - InitialStores: []string{"test", "delete"}, + InitialStores: initialStores, }) require.NoError(t, err) @@ -250,7 +252,7 @@ func TestRlog(t *testing.T) { require.Equal(t, 2, len(db.lastCommitInfo.StoreInfos)) - require.NoError(t, db.ApplyUpgrades([]*proto.TreeNameUpgrade{ + upgrades := []*proto.TreeNameUpgrade{ { Name: "newtest", RenameFrom: "test", @@ -259,15 +261,17 @@ func TestRlog(t *testing.T) { Name: "delete", Delete: true, }, - })) + } + 
require.NoError(t, db.ApplyUpgrades(upgrades)) _, err = db.Commit() require.NoError(t, err) require.NoError(t, db.Close()) - db, err = OpenDB(logger.NewNopLogger(), 0, Options{Dir: dir}) + // Reopen (MemIAVL will open the changelog from disk) + db, err = OpenDB(logger.NewNopLogger(), 0, Options{Dir: dir, InitialStores: initialStores}) require.NoError(t, err) - defer db.Close() // Close the reopened DB + t.Cleanup(func() { require.NoError(t, db.Close()) }) // Close the reopened DB require.Equal(t, "newtest", db.lastCommitInfo.StoreInfos[0].Name) require.Equal(t, 1, len(db.lastCommitInfo.StoreInfos)) @@ -296,14 +300,17 @@ func TestInitialVersion(t *testing.T) { value := "world" for _, initialVersion := range []int64{0, 1, 100} { dir := t.TempDir() + initialStores := []string{name} + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ Dir: dir, CreateIfMissing: true, - InitialStores: []string{name}, + InitialStores: initialStores, }) require.NoError(t, err) db.SetInitialVersion(initialVersion) - require.NoError(t, db.ApplyChangeSets(mockNameChangeSet(name, key, value))) + cs1 := mockNameChangeSet(name, key, value) + require.NoError(t, db.ApplyChangeSets(cs1)) v, err := db.Commit() require.NoError(t, err) if initialVersion <= 1 { @@ -313,7 +320,8 @@ func TestInitialVersion(t *testing.T) { } hash := db.LastCommitInfo().StoreInfos[0].CommitId.Hash require.Equal(t, "6032661ab0d201132db7a8fa1da6a0afe427e6278bd122c301197680ab79ca02", hex.EncodeToString(hash)) - require.NoError(t, db.ApplyChangeSets(mockNameChangeSet(name, key, "world1"))) + cs2 := mockNameChangeSet(name, key, "world1") + require.NoError(t, db.ApplyChangeSets(cs2)) v, err = db.Commit() require.NoError(t, err) hash = db.LastCommitInfo().StoreInfos[0].CommitId.Hash @@ -326,15 +334,17 @@ func TestInitialVersion(t *testing.T) { } require.NoError(t, db.Close()) - db, err = OpenDB(logger.NewNopLogger(), 0, Options{Dir: dir}) + // Reopen (MemIAVL will open the changelog from disk) + db, err = OpenDB(logger.NewNopLogger(), 0, Options{Dir: dir, InitialStores: initialStores}) require.NoError(t, err) - defer db.Close() // Close the reopened DB at end of loop iteration require.Equal(t, uint32(initialVersion), db.initialVersion.Load()) require.Equal(t, v, db.Version()) require.Equal(t, hex.EncodeToString(hash), hex.EncodeToString(db.LastCommitInfo().StoreInfos[0].CommitId.Hash)) - db.ApplyUpgrades([]*proto.TreeNameUpgrade{{Name: name1}}) - require.NoError(t, db.ApplyChangeSets(mockNameChangeSet(name1, key, value))) + upgrades1 := []*proto.TreeNameUpgrade{{Name: name1}} + db.ApplyUpgrades(upgrades1) + cs3 := mockNameChangeSet(name1, key, value) + require.NoError(t, db.ApplyChangeSets(cs3)) v, err = db.Commit() require.NoError(t, err) if initialVersion <= 1 { @@ -353,8 +363,10 @@ func TestInitialVersion(t *testing.T) { require.NoError(t, db.RewriteSnapshot(context.Background())) require.NoError(t, db.Reload()) - db.ApplyUpgrades([]*proto.TreeNameUpgrade{{Name: name2}}) - require.NoError(t, db.ApplyChangeSets(mockNameChangeSet(name2, key, value))) + upgrades2 := []*proto.TreeNameUpgrade{{Name: name2}} + db.ApplyUpgrades(upgrades2) + cs4 := mockNameChangeSet(name2, key, value) + require.NoError(t, db.ApplyChangeSets(cs4)) v, err = db.Commit() require.NoError(t, err) if initialVersion <= 1 { @@ -367,15 +379,19 @@ func TestInitialVersion(t *testing.T) { require.Equal(t, name2, info2.Name) require.Equal(t, v, info2.CommitId.Version) require.Equal(t, hex.EncodeToString(info.CommitId.Hash), hex.EncodeToString(info2.CommitId.Hash)) + + require.NoError(t, 
db.Close()) } } func TestLoadVersion(t *testing.T) { dir := t.TempDir() + initialStores := []string{"test"} + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ Dir: dir, CreateIfMissing: true, - InitialStores: []string{"test"}, + InitialStores: initialStores, }) require.NoError(t, err) @@ -391,7 +407,6 @@ func TestLoadVersion(t *testing.T) { // check the root hash require.Equal(t, RefHashes[db.Version()], db.WorkingCommitInfo().StoreInfos[0].CommitId.Hash) - _, err := db.Commit() require.NoError(t, err) }) @@ -402,9 +417,11 @@ func TestLoadVersion(t *testing.T) { if v == 0 { continue } + // Read-only loads use the same WAL to replay tmp, err := OpenDB(logger.NewNopLogger(), int64(v), Options{ - Dir: dir, - ReadOnly: true, + Dir: dir, + ReadOnly: true, + InitialStores: initialStores, }) require.NoError(t, err) require.Equal(t, RefHashes[v-1], tmp.TreeByName("test").RootHash()) @@ -481,17 +498,225 @@ func TestRlogIndexConversion(t *testing.T) { } } +// Regression test: on a fresh DB (version 0), the initial snapshot can contain 0 trees, +// but WAL replay may already contain changesets for initial store names. OpenDB must +// TestWALIndexDeltaComputation tests the O(1) delta-based WAL index conversion. +// This is critical because: +// 1. WAL indices and versions are both strictly contiguous +// 2. We compute delta once from the first WAL entry: delta = firstVersion - firstIndex +// 3. All conversions are then O(1): walIndex = version - delta +func TestWALIndexDeltaComputation(t *testing.T) { + testCases := []struct { + name string + initialVersion uint32 + numVersions int + rollbackTo int64 + }{ + { + name: "Test wal delta=0 and version = 1", + initialVersion: 0, + numVersions: 5, + rollbackTo: 3, + }, + { + name: "Test wal delta=9 and version = 10", + initialVersion: 10, + numVersions: 5, + rollbackTo: 12, + }, + { + name: "Test wal delta=99 and version = 100", + initialVersion: 100, + numVersions: 5, + rollbackTo: 102, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + initialStores := []string{"test"} + + // Open DB with initial version + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ + Dir: dir, + CreateIfMissing: true, + InitialStores: initialStores, + InitialVersion: tc.initialVersion, + }) + require.NoError(t, err) + + // Commit multiple versions + for i := 0; i < tc.numVersions; i++ { + cs := []*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value" + strconv.Itoa(i))}, + }, + }, + }, + } + require.NoError(t, db.ApplyChangeSets(cs)) + _, err = db.Commit() + require.NoError(t, err) + } + + // When initialVersion=0, first commit is version 1, so after N commits: version = N + // When initialVersion=X, first commit is version X, so after N commits: version = X + N - 1 + expectedVersion := int64(tc.numVersions) + if tc.initialVersion > 0 { + expectedVersion = int64(tc.initialVersion) + int64(tc.numVersions) - 1 + } + require.Equal(t, expectedVersion, db.Version()) + + require.NoError(t, db.Close()) + + // Reopen to verify delta is computed correctly from WAL entries + dbReopen, err := OpenDB(logger.NewNopLogger(), 0, Options{ + Dir: dir, + InitialStores: initialStores, + }) + require.NoError(t, err) + + // Now verify delta is computed correctly + // delta = firstVersion - firstIndex + // When initialVersion=0: firstVersion = 1, firstIndex = 1, delta = 0 + // When initialVersion=X: firstVersion = X, firstIndex = 1, delta = X - 1 + 
expectedDelta := int64(0) + if tc.initialVersion > 0 { + expectedDelta = int64(tc.initialVersion) - 1 + } + require.Equal(t, expectedDelta, dbReopen.walIndexDelta, "WAL index delta should be computed correctly") + + // Test versionToWALIndex + for i := 0; i < tc.numVersions; i++ { + var version int64 + if tc.initialVersion == 0 { + version = int64(i + 1) // versions: 1, 2, 3, 4, 5 + } else { + version = int64(tc.initialVersion) + int64(i) // versions: 10, 11, 12, 13, 14 + } + expectedIndex := uint64(i + 1) // WAL indices: 1, 2, 3, 4, 5 + require.Equal(t, expectedIndex, dbReopen.versionToWALIndex(version), + "versionToWALIndex(%d) should return %d", version, expectedIndex) + } + + require.NoError(t, dbReopen.Close()) + + // Now test rollback with LoadForOverwriting + db2, err := OpenDB(logger.NewNopLogger(), tc.rollbackTo, Options{ + Dir: dir, + InitialStores: initialStores, + LoadForOverwriting: true, + }) + require.NoError(t, err) + + // Verify rollback worked + require.Equal(t, tc.rollbackTo, db2.Version(), "Version should be rolled back to %d", tc.rollbackTo) + + // Verify WAL was truncated correctly + lastIndex, err := db2.GetWAL().LastOffset() + require.NoError(t, err) + expectedLastIndex := uint64(tc.rollbackTo - db2.walIndexDelta) + require.Equal(t, expectedLastIndex, lastIndex, "WAL should be truncated to index %d", expectedLastIndex) + + require.NoError(t, db2.Close()) + + // Reopen without LoadForOverwriting to verify persistence + db3, err := OpenDB(logger.NewNopLogger(), 0, Options{ + Dir: dir, + InitialStores: initialStores, + }) + require.NoError(t, err) + require.Equal(t, tc.rollbackTo, db3.Version(), "Version should persist as %d after reopen", tc.rollbackTo) + + require.NoError(t, db3.Close()) + }) + } +} + +// TestWALIndexDeltaWithZeroDelta specifically tests the case where delta=0. +// This was a bug where `walIndexDelta != 0` condition incorrectly skipped truncation +// when versions started at 1 (making delta = 1 - 1 = 0). 
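// For illustration only: the O(1) conversion these tests exercise reduces to two
// one-liners. versionToWALIndex and walIndexDelta are real (both are used directly
// in the assertions above); the inverse helper is a sketch added here just to make
// the invariant version = walIndex + delta explicit.
//
//	func (db *DB) versionToWALIndex(version int64) uint64 {
//		return uint64(version - db.walIndexDelta) // walIndex = version - delta
//	}
//
//	// hedged sketch of the inverse, given the same delta:
//	func walIndexToVersion(index uint64, delta int64) int64 {
//		return int64(index) + delta // version = walIndex + delta
//	}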
+func TestWALIndexDeltaWithZeroDelta(t *testing.T) { + dir := t.TempDir() + initialStores := []string{"test"} + + // Create DB with default initial version (0, so versions start at 1) + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ + Dir: dir, + CreateIfMissing: true, + InitialStores: initialStores, + }) + require.NoError(t, err) + + // Commit 5 versions (1, 2, 3, 4, 5) + for i := 0; i < 5; i++ { + cs := []*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value" + strconv.Itoa(i))}, + }, + }, + }, + } + require.NoError(t, db.ApplyChangeSets(cs)) + _, err = db.Commit() + require.NoError(t, err) + } + + require.Equal(t, int64(5), db.Version()) + // Critical: delta should be 0 (version 1 - index 1 = 0) + require.Equal(t, int64(0), db.walIndexDelta, "Delta should be 0 when versions start at 1") + + require.NoError(t, db.Close()) + + // Rollback to version 3 + db2, err := OpenDB(logger.NewNopLogger(), 3, Options{ + Dir: dir, + InitialStores: initialStores, + LoadForOverwriting: true, + }) + require.NoError(t, err) + + // This is the key assertion that would have failed with the bug + require.Equal(t, int64(3), db2.Version(), "Rollback should work even when delta=0") + + // Verify WAL truncation + lastIndex, err := db2.GetWAL().LastOffset() + require.NoError(t, err) + require.Equal(t, uint64(3), lastIndex, "WAL should be truncated to index 3") + + require.NoError(t, db2.Close()) + + // Verify rollback persisted after reopen + db3, err := OpenDB(logger.NewNopLogger(), 0, Options{ + Dir: dir, + InitialStores: initialStores, + }) + require.NoError(t, err) + require.Equal(t, int64(3), db3.Version(), "Rollback should persist after reopen") + + require.NoError(t, db3.Close()) +} + func TestEmptyValue(t *testing.T) { dir := t.TempDir() + initialStores := []string{"test"} + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ Dir: dir, - InitialStores: []string{"test"}, + InitialStores: initialStores, CreateIfMissing: true, ZeroCopy: true, }) require.NoError(t, err) - require.NoError(t, db.ApplyChangeSets([]*proto.NamedChangeSet{ + cs1 := []*proto.NamedChangeSet{ {Name: "test", Changeset: iavl.ChangeSet{ Pairs: []*iavl.KVPair{ {Key: []byte("hello1"), Value: []byte("")}, @@ -499,23 +724,26 @@ func TestEmptyValue(t *testing.T) { {Key: []byte("hello3"), Value: []byte("")}, }, }}, - })) + } + require.NoError(t, db.ApplyChangeSets(cs1)) _, err = db.Commit() require.NoError(t, err) - require.NoError(t, db.ApplyChangeSets([]*proto.NamedChangeSet{ + cs2 := []*proto.NamedChangeSet{ {Name: "test", Changeset: iavl.ChangeSet{ Pairs: []*iavl.KVPair{{Key: []byte("hello1"), Delete: true}}, }}, - })) + } + require.NoError(t, db.ApplyChangeSets(cs2)) version, err := db.Commit() require.NoError(t, err) require.NoError(t, db.Close()) - db, err = OpenDB(logger.NewNopLogger(), 0, Options{Dir: dir, ZeroCopy: true}) + // Reopen (MemIAVL will open the changelog from disk) + db, err = OpenDB(logger.NewNopLogger(), 0, Options{Dir: dir, ZeroCopy: true, InitialStores: initialStores}) require.NoError(t, err) - defer db.Close() // Close the reopened DB + t.Cleanup(func() { require.NoError(t, db.Close()) }) // Close the reopened DB require.Equal(t, version, db.Version()) } @@ -623,14 +851,16 @@ func TestRepeatedApplyChangeSet(t *testing.T) { }) require.NoError(t, err) + // Note: Multiple ApplyChangeSets calls are now allowed at DB level. + // The "one changeset per tree per version" validation is enforced by CommitStore. 
err = db.ApplyChangeSets([]*proto.NamedChangeSet{{Name: "test1"}}) - require.Error(t, err) + require.NoError(t, err) err = db.ApplyChangeSet("test1", iavl.ChangeSet{ Pairs: []*iavl.KVPair{ {Key: []byte("hello2"), Value: []byte("world2")}, }, }) - require.Error(t, err) + require.NoError(t, err) _, err = db.Commit() require.NoError(t, err) @@ -648,18 +878,20 @@ func TestRepeatedApplyChangeSet(t *testing.T) { }) require.NoError(t, err) + // Note: At DB level, multiple ApplyChangeSet calls with the same tree name are now allowed. + // The "one changeset per tree per version" validation is enforced by CommitStore. err = db.ApplyChangeSet("test1", iavl.ChangeSet{ Pairs: []*iavl.KVPair{ {Key: []byte("hello2"), Value: []byte("world2")}, }, }) - require.Error(t, err) + require.NoError(t, err) err = db.ApplyChangeSet("test2", iavl.ChangeSet{ Pairs: []*iavl.KVPair{ {Key: []byte("hello2"), Value: []byte("world2")}, }, }) - require.Error(t, err) + require.NoError(t, err) } func TestLoadMultiTreeWithCancelledContext(t *testing.T) { @@ -699,26 +931,32 @@ func TestLoadMultiTreeWithCancelledContext(t *testing.T) { func TestCatchupWithCancelledContext(t *testing.T) { // Create a DB with some data dir := t.TempDir() + initialStores := []string{"test"} + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ Dir: dir, CreateIfMissing: true, - InitialStores: []string{"test"}, + InitialStores: initialStores, }) require.NoError(t, err) - defer db.Close() + t.Cleanup(func() { require.NoError(t, db.Close()) }) + + wal := db.GetWAL() + require.NotNil(t, wal) // Add multiple versions to have changelog entries for i := 0; i < 5; i++ { - require.NoError(t, db.ApplyChangeSets([]*proto.NamedChangeSet{ + cs := []*proto.NamedChangeSet{ {Name: "test", Changeset: iavl.ChangeSet{ Pairs: []*iavl.KVPair{{Key: []byte("key"), Value: []byte("value" + strconv.Itoa(i))}}, }}, - })) + } + require.NoError(t, db.ApplyChangeSets(cs)) _, err = db.Commit() require.NoError(t, err) } - // Create snapshot at version 2 + // Create snapshot at version 5 require.NoError(t, db.RewriteSnapshot(context.Background())) // Load the snapshot (at version 5) @@ -728,13 +966,13 @@ func TestCatchupWithCancelledContext(t *testing.T) { Logger: logger.NewNopLogger(), }) require.NoError(t, err) - defer mtree.Close() + t.Cleanup(func() { require.NoError(t, mtree.Close()) }) // Catchup with cancelled context should return error ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - err = mtree.Catchup(ctx, db.streamHandler, 0) + err = mtree.Catchup(ctx, wal, db.walIndexDelta, 0) // If already caught up, no error; otherwise should get context.Canceled if err != nil { require.Equal(t, context.Canceled, err) diff --git a/sei-db/state_db/sc/memiavl/multitree.go b/sei-db/state_db/sc/memiavl/multitree.go index b12d5a20da..7a266d1ce7 100644 --- a/sei-db/state_db/sc/memiavl/multitree.go +++ b/sei-db/state_db/sc/memiavl/multitree.go @@ -12,13 +12,14 @@ import ( "time" "github.com/alitto/pond" - "github.com/sei-protocol/sei-chain/sei-db/changelog/types" + "golang.org/x/exp/slices" + "github.com/sei-protocol/sei-chain/sei-db/common/errors" "github.com/sei-protocol/sei-chain/sei-db/common/logger" "github.com/sei-protocol/sei-chain/sei-db/common/utils" "github.com/sei-protocol/sei-chain/sei-db/proto" + "github.com/sei-protocol/sei-chain/sei-db/wal" iavl "github.com/sei-protocol/sei-chain/sei-iavl" - "golang.org/x/exp/slices" ) const ( @@ -78,6 +79,9 @@ func NewEmptyMultiTree(initialVersion uint32) *MultiTree { func LoadMultiTree(ctx 
context.Context, dir string, opts Options) (*MultiTree, error) { startTime := time.Now() log := opts.Logger + if log == nil { + log = logger.NewNopLogger() + } metadata, err := readMetadata(dir) if err != nil { return nil, err @@ -354,49 +358,75 @@ func (t *MultiTree) UpdateCommitInfo() { t.lastCommitInfo = *t.buildCommitInfo(t.lastCommitInfo.Version) } -// Catchup replay the new entries in the Rlog file on the tree to catch up to the target or latest version. -func (t *MultiTree) Catchup(ctx context.Context, stream types.Stream[proto.ChangelogEntry], endVersion int64) error { +// Catchup replays WAL entries to catch up the tree to the target or latest version. +// delta is the difference between version and WAL index (version = walIndex + delta). +// endVersion specifies the target version (0 means catch up to latest). +func (t *MultiTree) Catchup(ctx context.Context, stream wal.ChangelogWAL, delta int64, endVersion int64) error { startTime := time.Now() + + // Get actual WAL index range + firstIndex, err := stream.FirstOffset() + if err != nil { + return fmt.Errorf("read rlog first index failed, %w", err) + } lastIndex, err := stream.LastOffset() if err != nil { return fmt.Errorf("read rlog last index failed, %w", err) } - iv := t.initialVersion.Load() - firstIndex := utils.VersionToIndex(utils.NextVersion(t.Version(), iv), iv) - if firstIndex > lastIndex { - // already up-to-date + // Empty WAL - nothing to replay + if lastIndex == 0 || firstIndex > lastIndex { return nil } - endIndex := lastIndex - if endVersion != 0 { - endIndex = utils.VersionToIndex(endVersion, iv) - } + currentVersion := t.Version() - if endIndex < firstIndex { - return fmt.Errorf("target index %d is pruned", endIndex) - } + // Calculate start index: walIndex = version - delta + // We want to start from currentVersion + 1 + startIndexSigned := currentVersion + 1 - delta - if endIndex > lastIndex { - return fmt.Errorf("target index %d is in the future, latest index: %d", endIndex, lastIndex) + // Ensure startIndex is within valid range (handle negative case before uint64 conversion) + var startIndex uint64 + if startIndexSigned <= 0 || uint64(startIndexSigned) < firstIndex { + startIndex = firstIndex + } else { + startIndex = uint64(startIndexSigned) + } + if startIndex > lastIndex { + // Nothing to replay - tree is already caught up + return nil } var replayCount = 0 - err = stream.Replay(firstIndex, endIndex, func(index uint64, entry proto.ChangelogEntry) error { + err = stream.Replay(startIndex, lastIndex, func(index uint64, entry proto.ChangelogEntry) error { // Check for cancellation select { case <-ctx.Done(): return ctx.Err() default: } + + // Safety check: skip entries we already have (should not happen with correct startIndex) + if entry.Version <= currentVersion { + return nil + } + + // If endVersion is specified, stop at that version + if endVersion != 0 && entry.Version > endVersion { + return nil + } + if err := t.ApplyUpgrades(entry.Upgrades); err != nil { return err } updatedTrees := make(map[string]bool) for _, cs := range entry.Changesets { treeName := cs.Name - t.TreeByName(treeName).ApplyChangeSetAsync(cs.Changeset) + tree := t.TreeByName(treeName) + if tree == nil { + return fmt.Errorf("unknown tree name %s during WAL replay (missing initial stores / upgrades)", treeName) + } + tree.ApplyChangeSetAsync(cs.Changeset) updatedTrees[treeName] = true } for _, tree := range t.trees { @@ -404,7 +434,7 @@ func (t *MultiTree) Catchup(ctx context.Context, stream types.Stream[proto.Chang 
tree.ApplyChangeSetAsync(iavl.ChangeSet{}) } } - t.lastCommitInfo.Version = utils.NextVersion(t.lastCommitInfo.Version, t.initialVersion.Load()) + t.lastCommitInfo.Version = entry.Version t.lastCommitInfo.StoreInfos = []proto.StoreInfo{} replayCount++ if replayCount%1000 == 0 { @@ -420,11 +450,14 @@ func (t *MultiTree) Catchup(ctx context.Context, stream types.Stream[proto.Chang if err != nil { return err } - t.UpdateCommitInfo() - replayElapsed := time.Since(startTime).Seconds() - t.logger.Info(fmt.Sprintf("Total replayed %d entries in %.1fs (%.1f entries/sec).\n", - replayCount, replayElapsed, float64(replayCount)/replayElapsed)) + if replayCount > 0 { + t.UpdateCommitInfo() + replayElapsed := time.Since(startTime).Seconds() + t.logger.Info(fmt.Sprintf("Total replayed %d entries in %.1fs (%.1f entries/sec).\n", + replayCount, replayElapsed, float64(replayCount)/replayElapsed)) + } + return nil } diff --git a/sei-db/state_db/sc/memiavl/opts.go b/sei-db/state_db/sc/memiavl/opts.go index ad6474890b..144af6bd6a 100644 --- a/sei-db/state_db/sc/memiavl/opts.go +++ b/sei-db/state_db/sc/memiavl/opts.go @@ -6,7 +6,6 @@ import ( "time" "github.com/sei-protocol/sei-chain/sei-db/common/logger" - "github.com/sei-protocol/sei-chain/sei-db/config" ) diff --git a/sei-db/state_db/sc/memiavl/proof_test.go b/sei-db/state_db/sc/memiavl/proof_test.go index ceaa2bb2f5..b31f7c8b5c 100644 --- a/sei-db/state_db/sc/memiavl/proof_test.go +++ b/sei-db/state_db/sc/memiavl/proof_test.go @@ -48,7 +48,7 @@ func TestProofs(t *testing.T) { snapshot, err := OpenSnapshot(tmpDir, opts) require.NoError(t, err) ptree := NewFromSnapshot(snapshot, opts) - defer func() { _ = ptree.Close() }() + t.Cleanup(func() { require.NoError(t, ptree.Close()) }) proof, err = ptree.GetMembershipProof(tc.existKey) require.NoError(t, err) diff --git a/sei-db/state_db/sc/memiavl/snapshot_methods_test.go b/sei-db/state_db/sc/memiavl/snapshot_methods_test.go index a63835845f..50eca5dfee 100644 --- a/sei-db/state_db/sc/memiavl/snapshot_methods_test.go +++ b/sei-db/state_db/sc/memiavl/snapshot_methods_test.go @@ -24,7 +24,7 @@ func TestSnapshotLeaf(t *testing.T) { opts.FillDefaults() snapshot, err := OpenSnapshot(snapshotDir, opts) require.NoError(t, err) - defer snapshot.Close() + t.Cleanup(func() { require.NoError(t, snapshot.Close()) }) // Test Leaf method if snapshot.leavesLen() > 0 { @@ -49,7 +49,7 @@ func TestSnapshotScanNodes(t *testing.T) { opts.FillDefaults() snapshot, err := OpenSnapshot(snapshotDir, opts) require.NoError(t, err) - defer snapshot.Close() + t.Cleanup(func() { require.NoError(t, snapshot.Close()) }) // Test ScanNodes count := 0 @@ -77,7 +77,7 @@ func TestSnapshotKey(t *testing.T) { opts.FillDefaults() snapshot, err := OpenSnapshot(snapshotDir, opts) require.NoError(t, err) - defer snapshot.Close() + t.Cleanup(func() { require.NoError(t, snapshot.Close()) }) // Test Key method via scanning leaves if snapshot.leavesLen() > 0 { @@ -126,7 +126,7 @@ func TestPrefetchSnapshot(t *testing.T) { snapshot, err := OpenSnapshot(snapshotDir, opts) require.NoError(t, err) - defer snapshot.Close() + t.Cleanup(func() { require.NoError(t, snapshot.Close()) }) require.NotNil(t, snapshot) } diff --git a/sei-db/state_db/sc/memiavl/snapshot_pipeline_test.go b/sei-db/state_db/sc/memiavl/snapshot_pipeline_test.go index 5e18fb5805..8b5cfbf4e1 100644 --- a/sei-db/state_db/sc/memiavl/snapshot_pipeline_test.go +++ b/sei-db/state_db/sc/memiavl/snapshot_pipeline_test.go @@ -30,7 +30,7 @@ func TestSnapshotWriterPipeline(t *testing.T) { opts.FillDefaults() 
snapshot, err := OpenSnapshot(snapshotDir, opts) require.NoError(t, err) - defer snapshot.Close() + t.Cleanup(func() { require.NoError(t, snapshot.Close()) }) require.Equal(t, uint32(tree.Version()), snapshot.Version()) require.Equal(t, tree.RootHash(), snapshot.RootHash()) @@ -222,7 +222,7 @@ func TestEmptySnapshotWrite(t *testing.T) { opts.FillDefaults() snapshot, err := OpenSnapshot(snapshotDir, opts) require.NoError(t, err) - defer snapshot.Close() + t.Cleanup(func() { require.NoError(t, snapshot.Close()) }) require.True(t, snapshot.IsEmpty()) require.Equal(t, uint32(0), snapshot.Version()) diff --git a/sei-db/state_db/sc/memiavl/snapshot_test.go b/sei-db/state_db/sc/memiavl/snapshot_test.go index 7db9ea1738..f5bb0015a7 100644 --- a/sei-db/state_db/sc/memiavl/snapshot_test.go +++ b/sei-db/state_db/sc/memiavl/snapshot_test.go @@ -141,10 +141,13 @@ func TestSnapshotImportExport(t *testing.T) { } func TestDBSnapshotRestore(t *testing.T) { + dir := t.TempDir() + initialStores := []string{"test", "test2"} + db, err := OpenDB(logger.NewNopLogger(), 0, Options{ - Dir: t.TempDir(), + Dir: dir, CreateIfMissing: true, - InitialStores: []string{"test", "test2"}, + InitialStores: initialStores, AsyncCommitBuffer: -1, }) require.NoError(t, err) @@ -163,17 +166,20 @@ func TestDBSnapshotRestore(t *testing.T) { require.NoError(t, db.ApplyChangeSets(cs)) _, err := db.Commit() require.NoError(t, err) + + // Create snapshot so export/import test can work without WAL + require.NoError(t, db.RewriteSnapshot(context.Background())) testSnapshotRoundTrip(t, db) } - require.NoError(t, db.RewriteSnapshot(context.Background())) require.NoError(t, db.Reload()) require.Equal(t, len(ChangeSets), int(db.metadata.CommitInfo.Version)) testSnapshotRoundTrip(t, db) } func testSnapshotRoundTrip(t *testing.T, db *DB) { - exporter, err := NewMultiTreeExporter(db.dir, uint32(db.Version()), false) + // Use NewMultiTreeExporter which loads from snapshot on disk + exporter, err := NewMultiTreeExporter(db.dir, uint32(db.Version()), true) // onlyAllowExportOnSnapshotVersion=true require.NoError(t, err) restoreDir := t.TempDir() diff --git a/sei-db/state_db/sc/memiavl/tree_test.go b/sei-db/state_db/sc/memiavl/tree_test.go index 04a9cd57e8..02f25e5c3a 100644 --- a/sei-db/state_db/sc/memiavl/tree_test.go +++ b/sei-db/state_db/sc/memiavl/tree_test.go @@ -259,7 +259,7 @@ func TestGetByIndex(t *testing.T) { snapshot, err := OpenSnapshot(dir, Options{}) require.NoError(t, err) ptree := NewFromSnapshot(snapshot, Options{ZeroCopy: true}) - defer func() { _ = ptree.Close() }() + t.Cleanup(func() { require.NoError(t, ptree.Close()) }) for i, pair := range changes.Pairs { idx, v := ptree.GetWithIndex(pair.Key) diff --git a/sei-db/state_db/sc/store.go b/sei-db/state_db/sc/store.go index 60127fb0c3..0807c5a0e8 100644 --- a/sei-db/state_db/sc/store.go +++ b/sei-db/state_db/sc/store.go @@ -5,6 +5,7 @@ import ( "math" "time" + "github.com/sei-protocol/sei-chain/sei-db/common/errors" "github.com/sei-protocol/sei-chain/sei-db/common/logger" "github.com/sei-protocol/sei-chain/sei-db/common/utils" "github.com/sei-protocol/sei-chain/sei-db/config" @@ -16,9 +17,11 @@ import ( var _ types.Committer = (*CommitStore)(nil) type CommitStore struct { - logger logger.Logger - db *memiavl.DB - opts memiavl.Options + logger logger.Logger + db *memiavl.DB + opts memiavl.Options + homeDir string + cfg config.StateCommitConfig } func NewCommitStore(homeDir string, logger logger.Logger, config config.StateCommitConfig) *CommitStore { @@ -26,8 +29,9 @@ func 
NewCommitStore(homeDir string, logger logger.Logger, config config.StateCom if config.Directory != "" { scDir = config.Directory } + commitDBPath := utils.GetCommitStorePath(scDir) opts := memiavl.Options{ - Dir: utils.GetCommitStorePath(scDir), + Dir: commitDBPath, ZeroCopy: config.ZeroCopy, AsyncCommitBuffer: config.AsyncCommitBuffer, SnapshotInterval: config.SnapshotInterval, @@ -39,8 +43,10 @@ func NewCommitStore(homeDir string, logger logger.Logger, config config.StateCom OnlyAllowExportOnSnapshotVersion: config.OnlyAllowExportOnSnapshotVersion, } commitStore := &CommitStore{ - logger: logger, - opts: opts, + logger: logger, + opts: opts, + homeDir: homeDir, + cfg: config, } return commitStore } @@ -54,11 +60,14 @@ func (cs *CommitStore) SetInitialVersion(initialVersion int64) error { } func (cs *CommitStore) Rollback(targetVersion int64) error { - options := cs.opts - options.LoadForOverwriting = true + // Close existing resources if cs.db != nil { _ = cs.db.Close() } + + options := cs.opts + options.LoadForOverwriting = true + db, err := memiavl.OpenDB(cs.logger, targetVersion, options) if err != nil { return err @@ -67,30 +76,37 @@ func (cs *CommitStore) Rollback(targetVersion int64) error { return nil } -// copyExisting is for creating new memiavl object given existing folder +// LoadVersion loads the specified version of the database. +// If copyExisting is true, creates a read-only copy for querying. func (cs *CommitStore) LoadVersion(targetVersion int64, copyExisting bool) (types.Committer, error) { cs.logger.Info(fmt.Sprintf("SeiDB load target memIAVL version %d, copyExisting = %v\n", targetVersion, copyExisting)) + if copyExisting { - opts := cs.opts - opts.ReadOnly = copyExisting - opts.CreateIfMissing = false - db, err := memiavl.OpenDB(cs.logger, targetVersion, opts) + // Create a read-only copy via NewCommitStore. 
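// (Why a fresh CommitStore rather than mutating the receiver: the copy gets its
// own opts and db handle, so closing it never tears down the writer's state. A
// sketch of the caller-side pattern, assuming a live CommitStore cs at version v:
//
//	ro, err := cs.LoadVersion(v, true) // independent read-only copy
//	if err != nil {
//		return err
//	}
//	defer func() { _ = ro.(*CommitStore).Close() }() // does not affect cs
//
// The copy replays the shared on-disk WAL up to v but never appends to it.)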
+ newCS := NewCommitStore(cs.homeDir, cs.logger, cs.cfg) + newCS.opts = cs.opts + newCS.opts.ReadOnly = true + newCS.opts.CreateIfMissing = false + + db, err := memiavl.OpenDB(cs.logger, targetVersion, newCS.opts) if err != nil { return nil, err } - return &CommitStore{ - logger: cs.logger, - db: db, - opts: opts, - }, nil + newCS.db = db + return newCS, nil } + + // Close existing resources if cs.db != nil { _ = cs.db.Close() } - db, err := memiavl.OpenDB(cs.logger, targetVersion, cs.opts) + + opts := cs.opts + db, err := memiavl.OpenDB(cs.logger, targetVersion, opts) if err != nil { return nil, err } + cs.db = db return cs, nil } @@ -112,10 +128,20 @@ func (cs *CommitStore) GetEarliestVersion() (int64, error) { } func (cs *CommitStore) ApplyChangeSets(changesets []*proto.NamedChangeSet) error { + if len(changesets) == 0 { + return nil + } + + // Apply to tree return cs.db.ApplyChangeSets(changesets) } func (cs *CommitStore) ApplyUpgrades(upgrades []*proto.TreeNameUpgrade) error { + if len(upgrades) == 0 { + return nil + } + + // Apply to tree return cs.db.ApplyUpgrades(upgrades) } @@ -135,27 +161,23 @@ func (cs *CommitStore) Exporter(version int64) (types.Exporter, error) { if version < 0 || version > math.MaxUint32 { return nil, fmt.Errorf("version %d out of range", version) } - - exporter, err := memiavl.NewMultiTreeExporter(cs.opts.Dir, uint32(version), cs.opts.OnlyAllowExportOnSnapshotVersion) - if err != nil { - return nil, err - } - return exporter, nil + return memiavl.NewMultiTreeExporter(cs.opts.Dir, uint32(version), cs.opts.OnlyAllowExportOnSnapshotVersion) } func (cs *CommitStore) Importer(version int64) (types.Importer, error) { - if version < 0 || version > math.MaxUint32 { return nil, fmt.Errorf("version %d out of range", version) } - - treeImporter, err := memiavl.NewMultiTreeImporter(cs.opts.Dir, uint64(version)) - if err != nil { - return nil, err - } - return treeImporter, nil + return memiavl.NewMultiTreeImporter(cs.opts.Dir, uint64(version)) } func (cs *CommitStore) Close() error { - return cs.db.Close() + var errs []error + + if cs.db != nil { + errs = append(errs, cs.db.Close()) + cs.db = nil + } + + return errors.Join(errs...) 
} diff --git a/sei-db/state_db/sc/store_test.go b/sei-db/state_db/sc/store_test.go new file mode 100644 index 0000000000..812e8f58c6 --- /dev/null +++ b/sei-db/state_db/sc/store_test.go @@ -0,0 +1,1035 @@ +package sc + +import ( + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/sei-protocol/sei-chain/sei-db/common/logger" + "github.com/sei-protocol/sei-chain/sei-db/config" + "github.com/sei-protocol/sei-chain/sei-db/proto" + iavl "github.com/sei-protocol/sei-chain/sei-iavl" +) + +func mustReadLastChangelogEntry(t *testing.T, cs *CommitStore) proto.ChangelogEntry { + t.Helper() + require.NotNil(t, cs.db) + w := cs.db.GetWAL() + require.NotNil(t, w) + last, err := w.LastOffset() + require.NoError(t, err) + require.Greater(t, last, uint64(0)) + e, err := w.ReadAt(last) + require.NoError(t, err) + return e +} + +func TestNewCommitStore(t *testing.T) { + dir := t.TempDir() + cfg := config.StateCommitConfig{ + ZeroCopy: true, + SnapshotInterval: 10, + } + + cs := NewCommitStore(dir, logger.NewNopLogger(), cfg) + require.NotNil(t, cs) + require.NotNil(t, cs.logger) + require.True(t, cs.opts.ZeroCopy) + require.Equal(t, uint32(10), cs.opts.SnapshotInterval) + require.True(t, cs.opts.CreateIfMissing) +} + +func TestNewCommitStoreWithCustomDirectory(t *testing.T) { + homeDir := t.TempDir() + customDir := t.TempDir() + cfg := config.StateCommitConfig{ + Directory: customDir, + } + + cs := NewCommitStore(homeDir, logger.NewNopLogger(), cfg) + require.NotNil(t, cs) + require.Contains(t, cs.opts.Dir, customDir) +} + +func TestInitialize(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + + stores := []string{"store1", "store2", "store3"} + cs.Initialize(stores) + + require.Equal(t, stores, cs.opts.InitialStores) +} + +func TestCommitStoreBasicOperations(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + // Load version 0 to initialize the DB + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Initial version should be 0 + require.Equal(t, int64(0), cs.Version()) + + // Apply changesets + changesets := []*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key1"), Value: []byte("value1")}, + {Key: []byte("key2"), Value: []byte("value2")}, + }, + }, + }, + } + err = cs.ApplyChangeSets(changesets) + require.NoError(t, err) + + // Commit + version, err := cs.Commit() + require.NoError(t, err) + require.Equal(t, int64(1), version) + + entry := mustReadLastChangelogEntry(t, cs) + require.Equal(t, int64(1), entry.Version) + require.Equal(t, changesets, entry.Changesets) + + // Version should be updated + require.Equal(t, int64(1), cs.Version()) +} + +func TestApplyChangeSetsEmpty(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Empty changesets should be no-op + err = cs.ApplyChangeSets(nil) + require.NoError(t, err) + + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{}) + require.NoError(t, err) +} + +func TestApplyUpgrades(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), 
config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Apply upgrades + upgrades := []*proto.TreeNameUpgrade{ + {Name: "newstore1"}, + {Name: "newstore2"}, + } + err = cs.ApplyUpgrades(upgrades) + require.NoError(t, err) + + // Apply more upgrades - should append + moreUpgrades := []*proto.TreeNameUpgrade{ + {Name: "newstore3"}, + } + err = cs.ApplyUpgrades(moreUpgrades) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + entry := mustReadLastChangelogEntry(t, cs) + // 4 upgrades total: initial store "test" + newstore1, newstore2, newstore3 + require.Len(t, entry.Upgrades, 4) +} + +func TestApplyUpgradesEmpty(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Empty upgrades should be no-op + err = cs.ApplyUpgrades(nil) + require.NoError(t, err) + + err = cs.ApplyUpgrades([]*proto.TreeNameUpgrade{}) + require.NoError(t, err) +} + +func TestLoadVersionCopyExisting(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + // First load to create the DB + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Apply and commit some data + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + require.NoError(t, cs.Close()) + + // Load with copyExisting=true should create a new readonly CommitStore + newCS, err := cs.LoadVersion(0, true) + require.NoError(t, err) + require.NotNil(t, newCS) + + // The returned store should be different from the original + newCommitStore, ok := newCS.(*CommitStore) + require.True(t, ok) + require.NotSame(t, cs, newCommitStore) + + require.NoError(t, newCommitStore.Close()) +} + +func TestCommitInfo(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // WorkingCommitInfo before any commit + workingInfo := cs.WorkingCommitInfo() + require.NotNil(t, workingInfo) + + // Apply and commit + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + // LastCommitInfo after commit + lastInfo := cs.LastCommitInfo() + require.NotNil(t, lastInfo) + require.Equal(t, int64(1), lastInfo.Version) +} + +func TestGetModuleByName(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test", "other"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Get existing module + module := cs.GetModuleByName("test") + require.NotNil(t, module) + + // Get 
non-existing module + module = cs.GetModuleByName("nonexistent") + require.Nil(t, module) +} + +func TestExporterVersionValidation(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Negative version should fail + _, err = cs.Exporter(-1) + require.Error(t, err) + require.Contains(t, err.Error(), "out of range") + + // Version > MaxUint32 should fail + _, err = cs.Exporter(math.MaxUint32 + 1) + require.Error(t, err) + require.Contains(t, err.Error(), "out of range") +} + +func TestImporterVersionValidation(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + + // Negative version should fail + _, err := cs.Importer(-1) + require.Error(t, err) + require.Contains(t, err.Error(), "out of range") + + // Version > MaxUint32 should fail + _, err = cs.Importer(math.MaxUint32 + 1) + require.Error(t, err) + require.Contains(t, err.Error(), "out of range") +} + +func TestClose(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Close should succeed + err = cs.Close() + require.NoError(t, err) + + // db should be nil after close + require.Nil(t, cs.db) + + // Close again should be safe (no-op) + err = cs.Close() + require.NoError(t, err) +} + +func TestRollback(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Commit a few versions + for i := 0; i < 3; i++ { + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value" + string(rune('0'+i)))}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + } + + require.Equal(t, int64(3), cs.Version()) + + // Rollback to version 2 (truncates WAL after version 2) + err = cs.Rollback(2) + require.NoError(t, err) + require.Equal(t, int64(2), cs.Version()) + + require.NoError(t, cs.Close()) +} + +func TestMultipleCommits(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Multiple commits + for i := 1; i <= 5; i++ { + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key" + string(rune('0'+i))), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + + version, err := cs.Commit() + require.NoError(t, err) + require.Equal(t, int64(i), version) + } + + require.Equal(t, int64(5), cs.Version()) +} + +func TestCommitWithUpgradesAndChangesets(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Apply upgrades first 
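// (Only the plain-Name form of TreeNameUpgrade is exercised below, which registers
// a brand-new empty tree. Assuming the proto mirrors upstream memiavl, the message
// also supports renames and deletions, roughly:
//
//	{Name: "b", RenameFrom: "a"} // rename tree "a" to "b"
//	{Name: "a", Delete: true}    // drop tree "a"
//
// Treat these two variants as a hedged sketch; only Name is used in this file.)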
+ err = cs.ApplyUpgrades([]*proto.TreeNameUpgrade{ + {Name: "newstore"}, + }) + require.NoError(t, err) + + // Then apply changesets to the new store + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "newstore", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + + // Commit + version, err := cs.Commit() + require.NoError(t, err) + require.Equal(t, int64(1), version) + entry := mustReadLastChangelogEntry(t, cs) + // 2 upgrades total: initial store "test" + "newstore" + require.Len(t, entry.Upgrades, 2) + require.Len(t, entry.Changesets, 1) +} + +func TestSetInitialVersion(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // Set initial version + err = cs.SetInitialVersion(100) + require.NoError(t, err) +} + +func TestGetVersions(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Commit a few versions + for i := 0; i < 3; i++ { + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + } + require.NoError(t, cs.Close()) + + // Create new CommitStore to test GetLatestVersion + cs2 := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs2.Initialize([]string{"test"}) + + latestVersion, err := cs2.GetLatestVersion() + require.NoError(t, err) + require.Equal(t, int64(3), latestVersion) +} + +func TestCreateWAL(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + defer func() { + err := cs.Close() + require.NoError(t, err) + }() + + // MemIAVL should have opened its changelog WAL. 
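// (GetWAL returns the generic WAL defined in sei-db/wal. For illustration, a full
// scan of the persisted entries via the interface's FirstOffset/LastOffset/Replay:
//
//	w := cs.db.GetWAL()
//	first, _ := w.FirstOffset()
//	last, _ := w.LastOffset()
//	_ = w.Replay(first, last, func(idx uint64, e proto.ChangelogEntry) error {
//		t.Logf("index=%d version=%d changesets=%d", idx, e.Version, len(e.Changesets))
//		return nil
//	})
//
// )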
+ require.NotNil(t, cs.db.GetWAL()) +} + +func TestLoadVersionReadOnlyWithWALReplay(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + // First load to create the DB + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Write data (MemIAVL will persist changelog internally) + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key1"), Value: []byte("value1")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + // Write more data + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key2"), Value: []byte("value2")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + require.Equal(t, int64(2), cs.Version()) + + // Load read-only copy - should replay from WAL + readOnlyCS, err := cs.LoadVersion(0, true) + require.NoError(t, err) + require.NotNil(t, readOnlyCS) + + // The read-only copy should have the same version after WAL replay + roCommitStore := readOnlyCS.(*CommitStore) + require.Equal(t, int64(2), roCommitStore.Version()) + + require.NotNil(t, roCommitStore.db.GetWAL()) + + // Clean up + require.NoError(t, roCommitStore.Close()) + require.NoError(t, cs.Close()) +} + +func TestLoadVersionReadOnlyCreatesOwnWAL(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + // First load to create the DB + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Commit some data with WAL entries + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + // Create multiple read-only copies + readOnly1, err := cs.LoadVersion(0, true) + require.NoError(t, err) + require.NotNil(t, readOnly1) + + readOnly2, err := cs.LoadVersion(0, true) + require.NoError(t, err) + require.NotNil(t, readOnly2) + + // Each should have its own WAL instance + ro1 := readOnly1.(*CommitStore) + ro2 := readOnly2.(*CommitStore) + require.NotNil(t, ro1.db.GetWAL()) + require.NotNil(t, ro2.db.GetWAL()) + + // Clean up + require.NoError(t, ro1.Close()) + require.NoError(t, ro2.Close()) + require.NoError(t, cs.Close()) +} + +func TestWALPersistenceAcrossRestart(t *testing.T) { + dir := t.TempDir() + + // First session: write data + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Write and commit + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key1"), Value: []byte("value1")}, + {Key: []byte("key2"), Value: []byte("value2")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + // More commits + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key3"), Value: []byte("value3")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + 
+ require.Equal(t, int64(2), cs.Version()) + require.NoError(t, cs.Close()) + + // Second session: reload and verify WAL replay + cs2 := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs2.Initialize([]string{"test"}) + + _, err = cs2.LoadVersion(0, false) + require.NoError(t, err) + + // Version should be restored via WAL replay + require.Equal(t, int64(2), cs2.Version()) + + // Data should be accessible + tree := cs2.GetModuleByName("test") + require.NotNil(t, tree) + + require.NoError(t, cs2.Close()) +} + +func TestRollbackWithWAL(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Commit multiple versions + for i := 0; i < 5; i++ { + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value" + string(rune('0'+i)))}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + } + + require.Equal(t, int64(5), cs.Version()) + require.NotNil(t, cs.db.GetWAL()) + + // Rollback to version 3 + err = cs.Rollback(3) + require.NoError(t, err) + require.Equal(t, int64(3), cs.Version()) + + // WAL should still exist after rollback + require.NotNil(t, cs.db.GetWAL()) + + require.NoError(t, cs.Close()) + + // Reopen and verify rollback persisted + cs2 := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs2.Initialize([]string{"test"}) + + _, err = cs2.LoadVersion(0, false) + require.NoError(t, err) + + // Version should be 3 after replay + require.Equal(t, int64(3), cs2.Version()) + + require.NoError(t, cs2.Close()) +} + +func TestRollbackCreatesWALIfNeeded(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + // Load and commit + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + // Close to release the WAL + require.NoError(t, cs.Close()) + + // After Close(), create a new CommitStore (Rollback opens the DB itself, recreating the WAL as needed) + cs2 := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs2.Initialize([]string{"test"}) + + // Rollback should work + require.NoError(t, cs2.Rollback(1)) + require.NoError(t, cs2.Close()) +} + +func TestCloseReleasesWAL(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + require.NotNil(t, cs.db) + require.NotNil(t, cs.db.GetWAL()) + + // Close + require.NoError(t, cs.Close()) + + // DB should be nil after close + require.Nil(t, cs.db) +} + +func TestLoadVersionReusesExistingWAL(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + // First load + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + require.NotNil(t, cs.db.GetWAL()) + + // Commit some data + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: 
iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + // Second load (non-copy) should close and recreate WAL + _, err = cs.LoadVersion(0, false) + require.NoError(t, err) + + require.NotNil(t, cs.db.GetWAL()) + + // Version should be replayed + require.Equal(t, int64(1), cs.Version()) + + require.NoError(t, cs.Close()) +} + +func TestReadOnlyCopyCannotCommit(t *testing.T) { + dir := t.TempDir() + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + // First load + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Commit initial data + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + + // Load read-only copy + readOnly, err := cs.LoadVersion(0, true) + require.NoError(t, err) + + roCS := readOnly.(*CommitStore) + + // Read-only copy should have read-only option set + require.True(t, roCS.opts.ReadOnly) + + // Attempting to commit on read-only copy should fail + // (this would fail at the memiavl.DB level) + _, err = roCS.Commit() + require.Error(t, err) + + require.NoError(t, roCS.Close()) + require.NoError(t, cs.Close()) +} + +// TestWALTruncationOnCommit tests that WAL is automatically truncated after commits +// when the earliest snapshot version advances past WAL entries. +func TestWALTruncationOnCommit(t *testing.T) { + dir := t.TempDir() + + // Configure with snapshot interval to trigger snapshot creation + cfg := config.StateCommitConfig{ + SnapshotInterval: 2, // Create snapshot every 2 blocks + SnapshotKeepRecent: 1, // Keep only 1 recent snapshot + } + cs := NewCommitStore(dir, logger.NewNopLogger(), cfg) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Commit multiple versions to trigger snapshot creation and WAL truncation + for i := 0; i < 10; i++ { + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value" + string(rune('0'+i)))}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + } + + // Verify current version + require.Equal(t, int64(10), cs.Version()) + + // Get WAL state + firstWALIndex, err := cs.db.GetWAL().FirstOffset() + require.NoError(t, err) + + // Get earliest snapshot version - may not exist yet if snapshots are async + earliestSnapshot, err := cs.GetEarliestVersion() + if err != nil { + // No snapshots yet (async snapshot creation), that's okay for this test + t.Logf("No snapshots created yet (async): %v", err) + require.NoError(t, cs.Close()) + return + } + + // WAL's first index should be greater than 1 if truncation happened + // (meaning early entries were removed) + // The exact value depends on snapshot creation timing and pruning + t.Logf("WAL first index: %d, earliest snapshot: %d", firstWALIndex, earliestSnapshot) + + // Key assertion: WAL entries before earliest snapshot should be truncated + // WAL version = index + delta, so WAL first version = firstIndex + delta + walDelta := cs.db.GetWALIndexDelta() + walFirstVersion := int64(firstWALIndex) + walDelta + require.GreaterOrEqual(t, walFirstVersion, 
earliestSnapshot, + "WAL first version should be >= earliest snapshot version after truncation") + + require.NoError(t, cs.Close()) +} + +// TestWALTruncationWithNoSnapshots tests that WAL truncation handles the case +// when no snapshots exist yet (should not panic or error). +func TestWALTruncationWithNoSnapshots(t *testing.T) { + dir := t.TempDir() + + // No snapshot interval configured, so no snapshots will be created + cs := NewCommitStore(dir, logger.NewNopLogger(), config.StateCommitConfig{}) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Commit a version + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value")}, + }, + }, + }, + }) + require.NoError(t, err) + + // Commit should succeed even though no snapshots exist + // (tryTruncateWAL should handle this gracefully) + _, err = cs.Commit() + require.NoError(t, err) + + // WAL should still have entries + firstIndex, err := cs.db.GetWAL().FirstOffset() + require.NoError(t, err) + require.Equal(t, uint64(1), firstIndex, "WAL should not be truncated when no snapshots exist") + + require.NoError(t, cs.Close()) +} + +// TestWALTruncationDelta tests that WAL truncation correctly uses the delta +// for version-to-index conversion with non-zero initial version. +func TestWALTruncationDelta(t *testing.T) { + dir := t.TempDir() + + cfg := config.StateCommitConfig{ + SnapshotInterval: 2, + SnapshotKeepRecent: 1, + } + cs := NewCommitStore(dir, logger.NewNopLogger(), cfg) + cs.Initialize([]string{"test"}) + + _, err := cs.LoadVersion(0, false) + require.NoError(t, err) + + // Set initial version to 100 + err = cs.SetInitialVersion(100) + require.NoError(t, err) + + // Commit multiple versions + for i := 0; i < 10; i++ { + err = cs.ApplyChangeSets([]*proto.NamedChangeSet{ + { + Name: "test", + Changeset: iavl.ChangeSet{ + Pairs: []*iavl.KVPair{ + {Key: []byte("key"), Value: []byte("value" + string(rune('0'+i)))}, + }, + }, + }, + }) + require.NoError(t, err) + _, err = cs.Commit() + require.NoError(t, err) + } + + // Verify version (should be 100 + 9 = 109) + require.Equal(t, int64(109), cs.Version()) + + // Close and reopen to verify delta is computed correctly from WAL + require.NoError(t, cs.Close()) + + // Reopen + cs2 := NewCommitStore(dir, logger.NewNopLogger(), cfg) + cs2.Initialize([]string{"test"}) + _, err = cs2.LoadVersion(0, false) + require.NoError(t, err) + + // Now verify delta is correct (computed from WAL entries) + walDelta := cs2.db.GetWALIndexDelta() + require.Equal(t, int64(99), walDelta, "Delta should be 99 (firstVersion 100 - firstIndex 1)") + + // Verify WAL truncation respects delta + firstWALIndex, err := cs2.db.GetWAL().FirstOffset() + require.NoError(t, err) + + // Get earliest snapshot version - may not exist yet if snapshots are async + earliestSnapshot, err := cs2.GetEarliestVersion() + if err != nil { + t.Logf("No snapshots created yet: %v", err) + require.NoError(t, cs2.Close()) + return + } + + walFirstVersion := int64(firstWALIndex) + walDelta + t.Logf("WAL first index: %d, WAL first version: %d, earliest snapshot: %d", + firstWALIndex, walFirstVersion, earliestSnapshot) + + require.GreaterOrEqual(t, walFirstVersion, earliestSnapshot, + "WAL first version should be >= earliest snapshot version") + + require.NoError(t, cs2.Close()) +} diff --git a/sei-db/state_db/ss/store.go b/sei-db/state_db/ss/store.go index 26a2cba9ac..83d3f7eda1 100644 
--- a/sei-db/state_db/ss/store.go +++ b/sei-db/state_db/ss/store.go @@ -3,13 +3,13 @@ package ss import ( "fmt" - "github.com/sei-protocol/sei-chain/sei-db/changelog/changelog" "github.com/sei-protocol/sei-chain/sei-db/common/logger" "github.com/sei-protocol/sei-chain/sei-db/common/utils" "github.com/sei-protocol/sei-chain/sei-db/config" "github.com/sei-protocol/sei-chain/sei-db/proto" "github.com/sei-protocol/sei-chain/sei-db/state_db/ss/pruning" "github.com/sei-protocol/sei-chain/sei-db/state_db/ss/types" + "github.com/sei-protocol/sei-chain/sei-db/wal" ) type BackendType string @@ -61,7 +61,7 @@ func NewStateStore(logger logger.Logger, homeDir string, ssConfig config.StateSt func RecoverStateStore(logger logger.Logger, changelogPath string, stateStore types.StateStore) error { ssLatestVersion := stateStore.GetLatestVersion() logger.Info(fmt.Sprintf("Recovering from changelog %s with latest SS version %d", changelogPath, ssLatestVersion)) - streamHandler, err := changelog.NewStream(logger, changelogPath, changelog.Config{}) + streamHandler, err := wal.NewChangelogWAL(logger, changelogPath, wal.Config{}) if err != nil { return err } diff --git a/sei-db/tools/cmd/seidb/operations/replay_changelog.go b/sei-db/tools/cmd/seidb/operations/replay_changelog.go index 77d3c89ac9..d9f4526c91 100644 --- a/sei-db/tools/cmd/seidb/operations/replay_changelog.go +++ b/sei-db/tools/cmd/seidb/operations/replay_changelog.go @@ -4,13 +4,14 @@ import ( "fmt" "path/filepath" - "github.com/sei-protocol/sei-chain/sei-db/changelog/changelog" + "github.com/spf13/cobra" + "github.com/sei-protocol/sei-chain/sei-db/common/logger" "github.com/sei-protocol/sei-chain/sei-db/config" "github.com/sei-protocol/sei-chain/sei-db/proto" "github.com/sei-protocol/sei-chain/sei-db/state_db/ss" "github.com/sei-protocol/sei-chain/sei-db/state_db/ss/types" - "github.com/spf13/cobra" + "github.com/sei-protocol/sei-chain/sei-db/wal" ) var ssStore types.StateStore @@ -41,7 +42,7 @@ func executeReplayChangelog(cmd *cobra.Command, _ []string) { } logDir := filepath.Join(dbDir, "changelog") - stream, err := changelog.NewStream(logger.NewNopLogger(), logDir, changelog.Config{}) + stream, err := wal.NewChangelogWAL(logger.NewNopLogger(), logDir, wal.Config{}) if err != nil { panic(err) } diff --git a/sei-db/wal/changelog.go b/sei-db/wal/changelog.go new file mode 100644 index 0000000000..b9a44e6e58 --- /dev/null +++ b/sei-db/wal/changelog.go @@ -0,0 +1,25 @@ +package wal + +import ( + "github.com/sei-protocol/sei-chain/sei-db/common/logger" + "github.com/sei-protocol/sei-chain/sei-db/proto" +) + +// ChangelogWAL is a type alias for a WAL specialized for ChangelogEntry. +type ChangelogWAL = GenericWAL[proto.ChangelogEntry] + +// NewChangelogWAL creates a new WAL for ChangelogEntry. +// This is a convenience wrapper that handles serialization automatically. 
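// Illustrative usage (a sketch; dir is caller-provided):
//
//	w, err := wal.NewChangelogWAL(logger.NewNopLogger(), dir, wal.Config{})
//	if err != nil {
//		return err
//	}
//	defer w.Close()
//
//	entry := proto.ChangelogEntry{Version: 1}
//	if err := w.Write(entry); err != nil { // appends at LastOffset()+1
//		return err
//	}
//	got, err := w.ReadAt(1) // round-trips through Marshal/Unmarshal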
+func NewChangelogWAL(logger logger.Logger, dir string, config Config) (ChangelogWAL, error) { + return NewWAL( + func(e proto.ChangelogEntry) ([]byte, error) { return e.Marshal() }, + func(data []byte) (proto.ChangelogEntry, error) { + var e proto.ChangelogEntry + err := e.Unmarshal(data) + return e, err + }, + logger, + dir, + config, + ) +} diff --git a/sei-db/changelog/types/types.go b/sei-db/wal/types.go similarity index 65% rename from sei-db/changelog/types/types.go rename to sei-db/wal/types.go index 597b605c25..be5584a6b2 100644 --- a/sei-db/changelog/types/types.go +++ b/sei-db/wal/types.go @@ -1,11 +1,15 @@ -package types +package wal -type Stream[T any] interface { - // Write will write a new entry to the log at the given index. - Write(offset uint64, entry T) error +// MarshalFn is a function that serializes an entry to bytes. +type MarshalFn[T any] func(entry T) ([]byte, error) - // CheckError check the error signal of async writes - CheckError() error +// UnmarshalFn is a function that deserializes bytes to an entry. +type UnmarshalFn[T any] func(data []byte) (T, error) + +// GenericWAL is a generic write-ahead log interface. +type GenericWAL[T any] interface { + // Write will append a new entry to the end of the log. + Write(entry T) error // TruncateBefore will remove all entries that are before the provided `offset` TruncateBefore(offset uint64) error @@ -14,7 +18,7 @@ type Stream[T any] interface { TruncateAfter(offset uint64) error // ReadAt will read the replay log at the given index - ReadAt(offset uint64) (*T, error) + ReadAt(offset uint64) (T, error) // FirstOffset returns the first written index of the log FirstOffset() (offset uint64, err error) @@ -28,7 +32,7 @@ type Stream[T any] interface { Close() error } -type Subscriber[T any] interface { +type GenericWALProcessor[T any] interface { // Start starts the subscriber processing goroutine Start() diff --git a/sei-db/changelog/changelog/utils.go b/sei-db/wal/utils.go similarity index 94% rename from sei-db/changelog/changelog/utils.go rename to sei-db/wal/utils.go index 2500e57dd1..d374f3a9e0 100644 --- a/sei-db/changelog/changelog/utils.go +++ b/sei-db/wal/utils.go @@ -1,4 +1,4 @@ -package changelog +package wal import ( "bytes" @@ -9,9 +9,10 @@ import ( "path/filepath" "unsafe" - iavl "github.com/sei-protocol/sei-chain/sei-iavl" "github.com/tidwall/gjson" "github.com/tidwall/wal" + + iavl "github.com/sei-protocol/sei-chain/sei-iavl" ) func LogPath(dir string) string { @@ -90,16 +91,16 @@ func loadNextBinaryEntry(data []byte) (n int, err error) { return n + size, nil } -func channelBatchRecv[T any](ch <-chan *T) []*T { +func channelBatchRecv[T any](ch <-chan T) []T { // block if channel is empty - item := <-ch - if item == nil { + item, ok := <-ch + if !ok { // channel is closed return nil } remaining := len(ch) - result := make([]*T, 0, remaining+1) + result := make([]T, 0, remaining+1) result = append(result, item) for i := 0; i < remaining; i++ { result = append(result, <-ch) diff --git a/sei-db/wal/wal.go b/sei-db/wal/wal.go new file mode 100644 index 0000000000..a1cd79ff84 --- /dev/null +++ b/sei-db/wal/wal.go @@ -0,0 +1,327 @@ +package wal + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/tidwall/wal" + + "github.com/sei-protocol/sei-chain/sei-db/common/logger" +) + +// WAL is a generic write-ahead log implementation. 
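// The entry type is fully generic: any payload works as long as the marshal and
// unmarshal functions round-trip it. A hypothetical JSON-backed instantiation
// (only ChangelogEntry is actually wired up in this package):
//
//	type Checkpoint struct{ Height int64 }
//
//	w, err := NewWAL(
//		func(c Checkpoint) ([]byte, error) { return json.Marshal(c) },
//		func(b []byte) (Checkpoint, error) {
//			var c Checkpoint
//			err := json.Unmarshal(b, &c)
//			return c, err
//		},
//		logger.NewNopLogger(), dir, Config{},
//	)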
+type WAL[T any] struct {
+	dir             string
+	log             *wal.Log
+	config          Config
+	logger          logger.Logger
+	marshal         MarshalFn[T]
+	unmarshal       UnmarshalFn[T]
+	writeChannel    chan T
+	mtx             sync.RWMutex   // guards WAL state: lazy init/close of writeChannel, isClosed checks
+	asyncWriteErrCh chan error     // buffered=1; async writer reports first error non-blocking
+	isClosed        bool
+	closeCh         chan struct{}  // signals shutdown to background goroutines
+	wg              sync.WaitGroup // tracks background goroutines (pruning, async writer)
+}
+
+type Config struct {
+	WriteBufferSize int
+	KeepRecent      uint64
+	PruneInterval   time.Duration
+}
+
+// NewWAL creates a new generic write-ahead log that persists entries.
+// marshal and unmarshal functions are used to serialize/deserialize entries.
+// Example:
+//
+//	NewWAL(
+//		func(e proto.ChangelogEntry) ([]byte, error) { return e.Marshal() },
+//		func(data []byte) (proto.ChangelogEntry, error) {
+//			var e proto.ChangelogEntry
+//			err := e.Unmarshal(data)
+//			return e, err
+//		},
+//		logger, dir, config,
+//	)
+func NewWAL[T any](
+	marshal MarshalFn[T],
+	unmarshal UnmarshalFn[T],
+	logger logger.Logger,
+	dir string,
+	config Config,
+) (*WAL[T], error) {
+	log, err := open(dir, &wal.Options{
+		NoSync: true,
+		NoCopy: true,
+	})
+	if err != nil {
+		return nil, err
+	}
+	w := &WAL[T]{
+		dir:             dir,
+		log:             log,
+		config:          config,
+		logger:          logger,
+		marshal:         marshal,
+		unmarshal:       unmarshal,
+		closeCh:         make(chan struct{}),
+		asyncWriteErrCh: make(chan error, 1),
+	}
+
+	// Start the auto-pruning goroutine
+	if config.KeepRecent > 0 && config.PruneInterval > 0 {
+		w.startPruning(config.KeepRecent, config.PruneInterval)
+	}
+	return w, nil
+}
+
+// Write appends a new entry to the end of the log.
+// Whether the write is blocking (synchronous) or asynchronous depends on the
+// configured buffer size: with WriteBufferSize == 0 the entry is persisted
+// before Write returns.
+// For async writes, this also surfaces any previous async write error.
+func (walLog *WAL[T]) Write(entry T) error {
+	// Hold walLog.mtx for the whole write: it prevents Close() from closing
+	// writeChannel mid-send and serializes synchronous writes. The send below
+	// blocks only until the async goroutine drains an entry from the buffer.
+	walLog.mtx.Lock()
+	defer walLog.mtx.Unlock()
+	if walLog.isClosed {
+		return errors.New("wal is closed")
+	}
+	if err := walLog.getAsyncWriteErrLocked(); err != nil {
+		return fmt.Errorf("async WAL write failed previously: %w", err)
+	}
+	writeBufferSize := walLog.config.WriteBufferSize
+	if writeBufferSize > 0 {
+		if walLog.writeChannel == nil {
+			walLog.writeChannel = make(chan T, writeBufferSize)
+			walLog.startAsyncWriteGoroutine()
+			walLog.logger.Info(fmt.Sprintf("WAL async write is enabled with buffer size %d", writeBufferSize))
+		}
+		walLog.writeChannel <- entry
+	} else {
+		// synchronous write: append right after the current last index
+		bz, err := walLog.marshal(entry)
+		if err != nil {
+			return err
+		}
+		lastOffset, err := walLog.log.LastIndex()
+		if err != nil {
+			return err
+		}
+		if err := walLog.log.Write(lastOffset+1, bz); err != nil {
+			return err
+		}
+	}
+	return nil
+}
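+
+// Usage sketch (illustrative, not part of this change; lg and dir are the
+// caller's logger and directory): with the default Config{} every Write
+// persists the entry before returning, while a positive WriteBufferSize
+// hands entries to the background goroutine below:
+//
+//	wSync, _ := NewChangelogWAL(lg, dir, Config{})                     // blocking writes
+//	wAsync, _ := NewChangelogWAL(lg, dir, Config{WriteBufferSize: 64}) // buffered async writes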
+
+// startAsyncWriteGoroutine starts the goroutine that drains writeChannel and
+// appends each entry to the log. It is started lazily by Write the first time
+// a buffered (async) write happens.
+func (walLog *WAL[T]) startAsyncWriteGoroutine() {
+	walLog.wg.Add(1)
+	ch := walLog.writeChannel
+	go func() {
+		defer walLog.wg.Done()
+		for entry := range ch {
+			bz, err := walLog.marshal(entry)
+			if err != nil {
+				walLog.recordAsyncWriteErr(err)
+				return
+			}
+			nextOffset, err := walLog.NextOffset()
+			if err != nil {
+				walLog.recordAsyncWriteErr(err)
+				return
+			}
+			err = walLog.log.Write(nextOffset, bz)
+			if err != nil {
+				walLog.recordAsyncWriteErr(err)
+				return
+			}
+		}
+	}()
+}
+
+// TruncateAfter will remove all entries that are after the provided `index`.
+// In other words the entry at `index` becomes the last entry in the log.
+func (walLog *WAL[T]) TruncateAfter(index uint64) error {
+	return walLog.log.TruncateBack(index)
+}
+
+// TruncateBefore will remove all entries that are before the provided `index`.
+// In other words the entry at `index` becomes the first entry in the log.
+// The underlying log serializes truncation against concurrent writes.
+func (walLog *WAL[T]) TruncateBefore(index uint64) error {
+	return walLog.log.TruncateFront(index)
+}
+
+// FirstOffset returns the first written offset/index of the log
+func (walLog *WAL[T]) FirstOffset() (index uint64, err error) {
+	return walLog.log.FirstIndex()
+}
+
+// LastOffset returns the last written offset/index of the log
+func (walLog *WAL[T]) LastOffset() (index uint64, err error) {
+	return walLog.log.LastIndex()
+}
+
+// NextOffset returns the offset the next write will land at
+func (walLog *WAL[T]) NextOffset() (index uint64, err error) {
+	lastOffset, err := walLog.log.LastIndex()
+	if err != nil {
+		return 0, err
+	}
+	return lastOffset + 1, nil
+}
+
+// ReadAt will read the log entry at the provided index
+func (walLog *WAL[T]) ReadAt(index uint64) (T, error) {
+	var zero T
+	bz, err := walLog.log.Read(index)
+	if err != nil {
+		return zero, fmt.Errorf("read log failed, %w", err)
+	}
+	entry, err := walLog.unmarshal(bz)
+	if err != nil {
+		return zero, fmt.Errorf("unmarshal log entry failed, %w", err)
+	}
+	return entry, nil
+}
+
+// Replay reads the log entries in [start, end] and processes each one with
+// the provided function
+func (walLog *WAL[T]) Replay(start uint64, end uint64, processFn func(index uint64, entry T) error) error {
+	for i := start; i <= end; i++ {
+		bz, err := walLog.log.Read(i)
+		if err != nil {
+			return fmt.Errorf("read log failed, %w", err)
+		}
+		entry, err := walLog.unmarshal(bz)
+		if err != nil {
+			return fmt.Errorf("unmarshal log entry failed, %w", err)
+		}
+		err = processFn(i, entry)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// startPruning periodically truncates the front of the log so that roughly
+// the newest keepRecent entries are retained.
+func (walLog *WAL[T]) startPruning(keepRecent uint64, pruneInterval time.Duration) {
+	walLog.wg.Add(1)
+	go func() {
+		defer walLog.wg.Done()
+		ticker := time.NewTicker(pruneInterval)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-walLog.closeCh:
+				return
+			case <-ticker.C:
+				lastIndex, err := walLog.log.LastIndex()
+				if err != nil {
+					walLog.logger.Error("failed to get last index for pruning", "err", err)
+					continue
+				}
+				firstIndex, err := walLog.log.FirstIndex()
+				if err != nil {
+					walLog.logger.Error("failed to get first index for pruning", "err", err)
+					continue
+				}
+				if lastIndex > keepRecent && (lastIndex-keepRecent) > firstIndex {
+					prunePos := lastIndex - keepRecent
+					if err := walLog.TruncateBefore(prunePos); err != nil {
+						walLog.logger.Error(fmt.Sprintf("failed to prune WAL up to index %d", prunePos), "err", err)
+					}
+				}
+			}
+		}
+	}()
+}
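+
+// Retention sketch (illustrative numbers, not defaults): with the Config
+// below the pruning goroutine wakes once a minute and truncates the front of
+// the log so that roughly the newest 100000 entries are kept:
+//
+//	w, _ := NewChangelogWAL(lg, dir, Config{
+//		KeepRecent:    100000,
+//		PruneInterval: time.Minute,
+//	})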
+
+// Close stops the background goroutines, drains any buffered async writes,
+// and closes the underlying log. Calling Close more than once is a no-op.
+func (walLog *WAL[T]) Close() error {
+	walLog.mtx.Lock()
+	defer walLog.mtx.Unlock()
+	// Close should only be executed once.
+	if walLog.isClosed {
+		return nil
+	}
+	// Signal background goroutines to stop.
+	close(walLog.closeCh)
+	if walLog.writeChannel != nil {
+		close(walLog.writeChannel)
+		walLog.writeChannel = nil
+	}
+	// Wait for all background goroutines (pruning + async write) to finish.
+	walLog.wg.Wait()
+	walLog.isClosed = true
+	return walLog.log.Close()
+}
+
+// recordAsyncWriteErr records the first async write error (non-blocking).
+func (walLog *WAL[T]) recordAsyncWriteErr(err error) {
+	if err == nil {
+		return
+	}
+	select {
+	case walLog.asyncWriteErrCh <- err:
+	default:
+		// already recorded
+	}
+}
+
+// getAsyncWriteErrLocked returns the async write error if present.
+// To keep the error "sticky" without an extra cached field, we implement
+// a "peek" by reading once and then non-blocking re-inserting the same
+// error back into the buffered channel.
+// Caller must hold walLog.mtx (read lock is sufficient).
+func (walLog *WAL[T]) getAsyncWriteErrLocked() error {
+	select {
+	case err := <-walLog.asyncWriteErrCh:
+		// Put it back so subsequent callers still observe it.
+		select {
+		case walLog.asyncWriteErrCh <- err:
+		default:
+		}
+		return err
+	default:
+		return nil
+	}
+}
+
+// open opens the write-ahead log; if the tail is corrupted, it truncates the
+// corrupted tail and retries the open.
+func open(dir string, opts *wal.Options) (*wal.Log, error) {
+	if opts == nil {
+		opts = wal.DefaultOptions
+	}
+	rlog, err := wal.Open(dir, opts)
+	if errors.Is(err, wal.ErrCorrupt) {
+		// try to truncate corrupted tail
+		var fis []os.DirEntry
+		fis, err = os.ReadDir(dir)
+		if err != nil {
+			return nil, fmt.Errorf("read wal dir failed: %w", err)
+		}
+		// os.ReadDir returns entries sorted by name, so the last matching
+		// file is the newest segment.
+		var lastSeg string
+		for _, fi := range fis {
+			if fi.IsDir() || len(fi.Name()) < 20 {
+				continue
+			}
+			lastSeg = fi.Name()
+		}
+
+		if len(lastSeg) == 0 {
+			return nil, err
+		}
+		if err = truncateCorruptedTail(filepath.Join(dir, lastSeg), opts.LogFormat); err != nil {
+			return nil, fmt.Errorf("truncate corrupted tail failed: %w", err)
+		}
+
+		// try again
+		return wal.Open(dir, opts)
+	}
+	return rlog, err
+}
diff --git a/sei-db/wal/wal_test.go b/sei-db/wal/wal_test.go
new file mode 100644
index 0000000000..76eaaf268b
--- /dev/null
+++ b/sei-db/wal/wal_test.go
@@ -0,0 +1,624 @@
+package wal
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"github.com/tidwall/wal"
+
+	"github.com/sei-protocol/sei-chain/sei-db/common/logger"
+	"github.com/sei-protocol/sei-chain/sei-db/proto"
+	iavl "github.com/sei-protocol/sei-chain/sei-iavl"
+)
+
+var (
+	ChangeSets = []iavl.ChangeSet{
+		{Pairs: MockKVPairs("hello", "world")},
+		{Pairs: MockKVPairs("hello1", "world1", "hello2", "world2")},
+		{Pairs: MockKVPairs("hello3", "world3")},
+	}
+
+	// marshal/unmarshal functions for testing
+	marshalEntry   = func(e proto.ChangelogEntry) ([]byte, error) { return e.Marshal() }
+	unmarshalEntry = func(data []byte) (proto.ChangelogEntry, error) {
+		var e proto.ChangelogEntry
+		err := e.Unmarshal(data)
+		return e, err
+	}
+)
+
+func TestOpenAndCorruptedTail(t *testing.T) {
+	opts := &wal.Options{
+		LogFormat: wal.JSON,
+	}
+	dir := t.TempDir()
+
+	testCases := []struct {
+		name      string
+		logs      []byte
+		lastIndex uint64
+	}{
+		{"failure-1", []byte("\n"), 0},
+		{"failure-2", []byte(`{}` + "\n"), 0},
+		{"failure-3", []byte(`{"index":"1"}` + "\n"), 0},
+		{"failure-4", []byte(`{"index":"1","data":"?"}`), 0},
+		{"failure-5", []byte(`{"index":1,"data":"?"}` + "\n" + `{"index":"1","data":"?"}`), 1},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t
*testing.T) { + err := os.WriteFile(filepath.Join(dir, "00000000000000000001"), tc.logs, 0o600) + require.NoError(t, err) + + _, err = wal.Open(dir, opts) + require.Equal(t, wal.ErrCorrupt, err) + + log, err := open(dir, opts) + require.NoError(t, err) + + lastIndex, err := log.LastIndex() + require.NoError(t, err) + require.Equal(t, tc.lastIndex, lastIndex) + }) + } +} + +func TestReplay(t *testing.T) { + changelog := prepareTestData(t) + var total = 0 + err := changelog.Replay(1, 2, func(index uint64, entry proto.ChangelogEntry) error { + total++ + switch index { + case 1: + require.Equal(t, "test", entry.Changesets[0].Name) + require.Equal(t, []byte("hello"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world"), entry.Changesets[0].Changeset.Pairs[0].Value) + case 2: + require.Equal(t, []byte("hello1"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world1"), entry.Changesets[0].Changeset.Pairs[0].Value) + require.Equal(t, []byte("hello2"), entry.Changesets[0].Changeset.Pairs[1].Key) + require.Equal(t, []byte("world2"), entry.Changesets[0].Changeset.Pairs[1].Value) + default: + require.Fail(t, fmt.Sprintf("unexpected index %d", index)) + } + return nil + }) + require.NoError(t, err) + require.Equal(t, 2, total) + err = changelog.Close() + require.NoError(t, err) +} + +func TestRandomRead(t *testing.T) { + changelog := prepareTestData(t) + entry, err := changelog.ReadAt(2) + require.NoError(t, err) + require.Equal(t, []byte("hello1"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world1"), entry.Changesets[0].Changeset.Pairs[0].Value) + require.Equal(t, []byte("hello2"), entry.Changesets[0].Changeset.Pairs[1].Key) + require.Equal(t, []byte("world2"), entry.Changesets[0].Changeset.Pairs[1].Value) + entry, err = changelog.ReadAt(1) + require.NoError(t, err) + require.Equal(t, []byte("hello"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world"), entry.Changesets[0].Changeset.Pairs[0].Value) + entry, err = changelog.ReadAt(3) + require.NoError(t, err) + require.Equal(t, []byte("hello3"), entry.Changesets[0].Changeset.Pairs[0].Key) + require.Equal(t, []byte("world3"), entry.Changesets[0].Changeset.Pairs[0].Value) +} + +func prepareTestData(t *testing.T) *WAL[proto.ChangelogEntry] { + dir := t.TempDir() + changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + require.NoError(t, err) + writeTestData(t, changelog) + return changelog +} + +func writeTestData(t *testing.T, changelog *WAL[proto.ChangelogEntry]) { + for _, changes := range ChangeSets { + cs := []*proto.NamedChangeSet{ + { + Name: "test", + Changeset: changes, + }, + } + entry := proto.ChangelogEntry{} + entry.Changesets = cs + require.NoError(t, changelog.Write(entry)) + } +} + +func TestSynchronousWrite(t *testing.T) { + changelog := prepareTestData(t) + lastIndex, err := changelog.LastOffset() + require.NoError(t, err) + require.Equal(t, uint64(3), lastIndex) + +} + +func TestAsyncWrite(t *testing.T) { + dir := t.TempDir() + changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10}) + require.NoError(t, err) + for _, changes := range ChangeSets { + cs := []*proto.NamedChangeSet{ + { + Name: "test", + Changeset: changes, + }, + } + entry := &proto.ChangelogEntry{} + entry.Changesets = cs + err := changelog.Write(*entry) + require.NoError(t, err) + } + err = changelog.Close() + require.NoError(t, err) + changelog, err = NewWAL(marshalEntry, 
unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10})
+	require.NoError(t, err)
+	lastIndex, err := changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+}
+
+func TestOpenWithNilOptions(t *testing.T) {
+	dir := t.TempDir()
+
+	// Test that open function handles nil options correctly
+	log, err := open(dir, nil)
+	require.NoError(t, err)
+	require.NotNil(t, log)
+
+	// Verify the log is functional by checking first and last index
+	firstIndex, err := log.FirstIndex()
+	require.NoError(t, err)
+	require.Equal(t, uint64(0), firstIndex)
+
+	lastIndex, err := log.LastIndex()
+	require.NoError(t, err)
+	require.Equal(t, uint64(0), lastIndex)
+
+	// Clean up
+	err = log.Close()
+	require.NoError(t, err)
+}
+
+func TestTruncateAfter(t *testing.T) {
+	changelog := prepareTestData(t)
+	t.Cleanup(func() { require.NoError(t, changelog.Close()) })
+
+	// Verify we have 3 entries
+	lastIndex, err := changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+
+	// Truncate after index 2 (removes entry 3)
+	err = changelog.TruncateAfter(2)
+	require.NoError(t, err)
+
+	// Verify last index is now 2
+	lastIndex, err = changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(2), lastIndex)
+
+	// Verify the next write lands right after the truncation point - write a
+	// new entry and check its index
+	entry := &proto.ChangelogEntry{}
+	entry.Changesets = []*proto.NamedChangeSet{{Name: "new", Changeset: iavl.ChangeSet{Pairs: MockKVPairs("new", "entry")}}}
+	err = changelog.Write(*entry)
+	require.NoError(t, err)
+
+	lastIndex, err = changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+}
+
+func TestTruncateBefore(t *testing.T) {
+	changelog := prepareTestData(t)
+	t.Cleanup(func() { require.NoError(t, changelog.Close()) })
+
+	// Verify we have 3 entries starting at 1
+	firstIndex, err := changelog.FirstOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(1), firstIndex)
+
+	lastIndex, err := changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+
+	// Truncate before index 2 (removes entry 1)
+	err = changelog.TruncateBefore(2)
+	require.NoError(t, err)
+
+	// Verify first index is now 2
+	firstIndex, err = changelog.FirstOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(2), firstIndex)
+
+	// Last index should still be 3
+	lastIndex, err = changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+
+	// Verify entry 2 is still readable
+	entry, err := changelog.ReadAt(2)
+	require.NoError(t, err)
+	require.Equal(t, []byte("hello1"), entry.Changesets[0].Changeset.Pairs[0].Key)
+}
+
+func TestCloseSyncMode(t *testing.T) {
+	dir := t.TempDir()
+	changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{})
+	require.NoError(t, err)
+
+	// Write some data in sync mode
+	writeTestData(t, changelog)
+
+	// Close the changelog
+	err = changelog.Close()
+	require.NoError(t, err)
+
+	// Verify isClosed is set
+	require.True(t, changelog.isClosed)
+
+	// Reopen and verify data persisted
+	changelog2, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{})
+	require.NoError(t, err)
+	t.Cleanup(func() { require.NoError(t, changelog2.Close()) })
+
+	lastIndex, err := changelog2.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+}
+
+func TestReadAtNonExistent(t *testing.T) {
+	changelog := prepareTestData(t)
+	t.Cleanup(func() {
require.NoError(t, changelog.Close()) })
+
+	// Try to read an entry that doesn't exist
+	_, err := changelog.ReadAt(100)
+	require.Error(t, err)
+}
+
+func TestReplayWithError(t *testing.T) {
+	changelog := prepareTestData(t)
+	t.Cleanup(func() { require.NoError(t, changelog.Close()) })
+
+	// Replay with a function that returns an error
+	expectedErr := fmt.Errorf("test error")
+	err := changelog.Replay(1, 3, func(index uint64, entry proto.ChangelogEntry) error {
+		if index == 2 {
+			return expectedErr
+		}
+		return nil
+	})
+	require.Error(t, err)
+	require.Equal(t, expectedErr, err)
+}
+
+func TestReopenAndContinueWrite(t *testing.T) {
+	dir := t.TempDir()
+
+	// Create and write initial data
+	changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{})
+	require.NoError(t, err)
+	writeTestData(t, changelog)
+	err = changelog.Close()
+	require.NoError(t, err)
+
+	// Reopen and continue writing
+	changelog2, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{})
+	require.NoError(t, err)
+
+	// Verify the write offset picks up where it left off after reopen
+	lastIndex, err := changelog2.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+
+	// Write more data
+	entry := &proto.ChangelogEntry{}
+	entry.Changesets = []*proto.NamedChangeSet{{Name: "continued", Changeset: iavl.ChangeSet{Pairs: MockKVPairs("key4", "value4")}}}
+	err = changelog2.Write(*entry)
+	require.NoError(t, err)
+
+	// Verify new entry is at index 4
+	lastIndex, err = changelog2.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(4), lastIndex)
+
+	// Verify data integrity
+	readEntry, err := changelog2.ReadAt(4)
+	require.NoError(t, err)
+	require.Equal(t, "continued", readEntry.Changesets[0].Name)
+	require.Equal(t, []byte("key4"), readEntry.Changesets[0].Changeset.Pairs[0].Key)
+
+	err = changelog2.Close()
+	require.NoError(t, err)
+}
+
+func TestEmptyLog(t *testing.T) {
+	dir := t.TempDir()
+	changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{})
+	require.NoError(t, err)
+	t.Cleanup(func() { require.NoError(t, changelog.Close()) })
+
+	// Empty log should have 0 for both first and last index
+	firstIndex, err := changelog.FirstOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(0), firstIndex)
+
+	lastIndex, err := changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(0), lastIndex)
+}
+
+func TestAsyncWriteNoError(t *testing.T) {
+	dir := t.TempDir()
+	changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10})
+	require.NoError(t, err)
+
+	// Write some data to initialize async mode; no async error should surface
+	entry := &proto.ChangelogEntry{}
+	entry.Changesets = []*proto.NamedChangeSet{{Name: "test", Changeset: iavl.ChangeSet{Pairs: MockKVPairs("k", "v")}}}
+	err = changelog.Write(*entry)
+	require.NoError(t, err)
+
+	err = changelog.Close()
+	require.NoError(t, err)
+}
+
+func TestFirstAndLastOffset(t *testing.T) {
+	changelog := prepareTestData(t)
+	t.Cleanup(func() { require.NoError(t, changelog.Close()) })
+
+	firstIndex, err := changelog.FirstOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(1), firstIndex)
+
+	lastIndex, err := changelog.LastOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+}
+
+func TestAsyncWriteReopenAndContinue(t *testing.T) {
+	dir := t.TempDir()
+
+	// Create with async write and write data
+	changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(),
dir, Config{WriteBufferSize: 10}) + require.NoError(t, err) + + for _, changes := range ChangeSets { + cs := []*proto.NamedChangeSet{{Name: "test", Changeset: changes}} + entry := &proto.ChangelogEntry{Changesets: cs} + err := changelog.Write(*entry) + require.NoError(t, err) + } + + err = changelog.Close() + require.NoError(t, err) + + // Reopen with async write and continue + changelog2, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 10}) + require.NoError(t, err) + + // Write more entries + for i := 0; i < 3; i++ { + entry := &proto.ChangelogEntry{} + entry.Changesets = []*proto.NamedChangeSet{{Name: fmt.Sprintf("batch2-%d", i), Changeset: iavl.ChangeSet{Pairs: MockKVPairs("k", "v")}}} + err := changelog2.Write(*entry) + require.NoError(t, err) + } + + err = changelog2.Close() + require.NoError(t, err) + + // Reopen and verify all 6 entries + changelog3, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, changelog3.Close()) }) + + lastIndex, err := changelog3.LastOffset() + require.NoError(t, err) + require.Equal(t, uint64(6), lastIndex) +} + +func TestReplaySingleEntry(t *testing.T) { + changelog := prepareTestData(t) + t.Cleanup(func() { require.NoError(t, changelog.Close()) }) + + var count int + err := changelog.Replay(2, 2, func(index uint64, entry proto.ChangelogEntry) error { + count++ + require.Equal(t, uint64(2), index) + return nil + }) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestWriteMultipleChangesets(t *testing.T) { + dir := t.TempDir() + changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{}) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, changelog.Close()) }) + + // Write entry with multiple changesets + entry := &proto.ChangelogEntry{ + Changesets: []*proto.NamedChangeSet{ + {Name: "store1", Changeset: iavl.ChangeSet{Pairs: MockKVPairs("a", "1")}}, + {Name: "store2", Changeset: iavl.ChangeSet{Pairs: MockKVPairs("b", "2")}}, + {Name: "store3", Changeset: iavl.ChangeSet{Pairs: MockKVPairs("c", "3")}}, + }, + } + err = changelog.Write(*entry) + require.NoError(t, err) + + // Read and verify + readEntry, err := changelog.ReadAt(1) + require.NoError(t, err) + require.Len(t, readEntry.Changesets, 3) + require.Equal(t, "store1", readEntry.Changesets[0].Name) + require.Equal(t, "store2", readEntry.Changesets[1].Name) + require.Equal(t, "store3", readEntry.Changesets[2].Name) +} + +func TestConcurrentCloseWithInFlightAsyncWrites(t *testing.T) { + dir := t.TempDir() + changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{WriteBufferSize: 8}) + require.NoError(t, err) + + // Intentionally avoid t.Cleanup here: we want Close() to race with in-flight async writes. + + // Writers: keep calling Write() until it returns an error (which should happen once Close() starts). + // If Write() or Close() deadlocks, the test will time out waiting for the goroutines to exit. + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + entry := proto.ChangelogEntry{ + Changesets: []*proto.NamedChangeSet{{ + Name: "test", + Changeset: iavl.ChangeSet{Pairs: MockKVPairs("k", "v")}, + }}, + } + if err := changelog.Write(entry); err != nil { + return + } + } + }() + } + + // Ensure we actually have in-flight async activity before closing. 
+ require.Eventually(t, func() bool { + last, err := changelog.LastOffset() + return err == nil && last > 0 + }, 1*time.Second, 10*time.Millisecond, "expected some writes before Close()") + + closeDone := make(chan struct{}) + closeErr := make(chan error, 1) + go func() { + closeErr <- changelog.Close() + close(closeDone) + }() + + // Wait for writers to observe Close() and exit. + writersDone := make(chan struct{}) + go func() { wg.Wait(); close(writersDone) }() + require.Eventually(t, func() bool { + select { + case <-writersDone: + return true + default: + return false + } + }, 3*time.Second, 10*time.Millisecond, "writers did not exit (possible deadlock)") + + // Ensure Close() returns too. + require.Eventually(t, func() bool { + select { + case <-closeDone: + return true + default: + return false + } + }, 3*time.Second, 10*time.Millisecond, "Close() did not return (possible deadlock)") + + require.NoError(t, <-closeErr) +} + +func TestConcurrentTruncateBeforeWithAsyncWrites(t *testing.T) { + dir := t.TempDir() + changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{ + WriteBufferSize: 10, + KeepRecent: 10, + PruneInterval: 1 * time.Millisecond, + }) + require.NoError(t, err) + + const ( + totalWrites = 50 + ) + + // Write a bunch of entries (async writes). We'll wait until they're all persisted. + for i := 1; i <= totalWrites; i++ { + entry := proto.ChangelogEntry{ + Changesets: []*proto.NamedChangeSet{{ + Name: "test", + Changeset: iavl.ChangeSet{Pairs: MockKVPairs(fmt.Sprintf("k-%d", i), "v")}, + }}, + } + require.NoError(t, changelog.Write(entry)) + } + + // Ensure async writer has flushed to disk. + require.Eventually(t, func() bool { + last, err := changelog.LastOffset() + return err == nil && last == uint64(totalWrites) + }, 3*time.Second, 10*time.Millisecond, "async writes did not flush") + + // Let the background pruning goroutine run and advance FirstOffset. + require.Eventually(t, func() bool { + first, err := changelog.FirstOffset() + return err == nil && first > 1 + }, 3*time.Second, 10*time.Millisecond, "background pruning did not advance FirstOffset") + + // Manual front truncation while pruning is enabled. + firstBefore, err := changelog.FirstOffset() + require.NoError(t, err) + last, err := changelog.LastOffset() + require.NoError(t, err) + require.True(t, firstBefore < last, "expected a non-empty range after writes") + + require.NoError(t, changelog.TruncateBefore(firstBefore+1)) + require.Eventually(t, func() bool { + first, err := changelog.FirstOffset() + return err == nil && first >= firstBefore+1 + }, 3*time.Second, 10*time.Millisecond, "manual truncation did not take effect") + + // Read first + last entries to ensure no corruption (decode succeeds; expected structure). 
+	first, err := changelog.FirstOffset()
+	require.NoError(t, err)
+	last, err = changelog.LastOffset()
+	require.NoError(t, err)
+	require.True(t, first <= last, "invalid WAL range after pruning/truncation")
+
+	firstEntry, err := changelog.ReadAt(first)
+	require.NoError(t, err)
+	require.NotEmpty(t, firstEntry.Changesets)
+	require.Equal(t, "test", firstEntry.Changesets[0].Name)
+	require.NotEmpty(t, firstEntry.Changesets[0].Changeset.Pairs)
+
+	lastEntry, err := changelog.ReadAt(last)
+	require.NoError(t, err)
+	require.NotEmpty(t, lastEntry.Changesets)
+	require.Equal(t, "test", lastEntry.Changesets[0].Name)
+	require.NotEmpty(t, lastEntry.Changesets[0].Changeset.Pairs)
+
+	require.NoError(t, changelog.Close())
+}
+
+func TestGetLastIndex(t *testing.T) {
+	dir := t.TempDir()
+	changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{})
+	require.NoError(t, err)
+	writeTestData(t, changelog)
+	err = changelog.Close()
+	require.NoError(t, err)
+
+	// Use the utility function to get the last index without opening the WAL
+	lastIndex, err := GetLastIndex(dir)
+	require.NoError(t, err)
+	require.Equal(t, uint64(3), lastIndex)
+}
+
+func TestLogPath(t *testing.T) {
+	path := LogPath("/some/dir")
+	require.Equal(t, "/some/dir/changelog", path)
+}
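+
+// TestNextOffsetEmptyLog is a small illustrative check (it assumes nothing
+// beyond the API above): on a fresh log LastIndex is 0, so NextOffset
+// should return 1.
+func TestNextOffsetEmptyLog(t *testing.T) {
+	dir := t.TempDir()
+	changelog, err := NewWAL(marshalEntry, unmarshalEntry, logger.NewNopLogger(), dir, Config{})
+	require.NoError(t, err)
+	t.Cleanup(func() { require.NoError(t, changelog.Close()) })
+
+	nextOffset, err := changelog.NextOffset()
+	require.NoError(t, err)
+	require.Equal(t, uint64(1), nextOffset)
+}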