-
Notifications
You must be signed in to change notification settings - Fork 6
drpcmanager: replace streamBuffer with a streams registry #47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| // Copyright (C) 2026 Cockroach Labs. | ||
| // See LICENSE for copying information. | ||
|
|
||
| package drpcmanager | ||
|
|
||
| import ( | ||
| "sync" | ||
|
|
||
| "storj.io/drpc/drpcstream" | ||
| ) | ||
|
|
||
| // activeStreams is a thread-safe map of stream IDs to stream objects. | ||
| // It is used by the Manager to track active streams for lifecycle management. | ||
| type activeStreams struct { | ||
| mu sync.RWMutex | ||
| streams map[uint64]*drpcstream.Stream | ||
| closed bool | ||
| } | ||
|
|
||
| func newActiveStreams() *activeStreams { | ||
| return &activeStreams{ | ||
| streams: make(map[uint64]*drpcstream.Stream), | ||
| } | ||
| } | ||
|
|
||
| // Add adds a stream. It returns an error if the collection is closed or if a | ||
| // stream with the same ID already exists. | ||
| func (r *activeStreams) Add(id uint64, stream *drpcstream.Stream) error { | ||
| if stream == nil { | ||
| return managerClosed.New("stream can't be nil") | ||
| } | ||
|
|
||
| r.mu.Lock() | ||
| defer r.mu.Unlock() | ||
|
|
||
| if r.closed { | ||
| return managerClosed.New("add to closed collection") | ||
| } | ||
| if _, ok := r.streams[id]; ok { | ||
| return managerClosed.New("duplicate stream id") | ||
| } | ||
| r.streams[id] = stream | ||
| return nil | ||
| } | ||
|
|
||
| // Remove removes a stream. It is a no-op if the stream is not present or if | ||
| // the collection has been closed. | ||
| func (r *activeStreams) Remove(id uint64) { | ||
| r.mu.Lock() | ||
| defer r.mu.Unlock() | ||
|
|
||
| if r.streams != nil { | ||
| delete(r.streams, id) | ||
| } | ||
| } | ||
|
|
||
| // Get returns the stream for the given ID and whether it was found. | ||
| func (r *activeStreams) Get(id uint64) (*drpcstream.Stream, bool) { | ||
| r.mu.RLock() | ||
| defer r.mu.RUnlock() | ||
|
|
||
| s, ok := r.streams[id] | ||
| return s, ok | ||
| } | ||
|
|
||
| // GetLatest returns the stream with the highest ID, or nil if empty. | ||
| func (r *activeStreams) GetLatest() *drpcstream.Stream { | ||
| r.mu.RLock() | ||
| defer r.mu.RUnlock() | ||
|
|
||
| var latest *drpcstream.Stream | ||
| for _, s := range r.streams { | ||
|
shubhamdhama marked this conversation as resolved.
|
||
| if latest == nil || latest.ID() < s.ID() { | ||
| latest = s | ||
| } | ||
| } | ||
| return latest | ||
| } | ||
|
|
||
| // Close marks the collection as closed, preventing future Add calls. | ||
| // It does not cancel any streams. | ||
| func (r *activeStreams) Close() { | ||
| r.mu.Lock() | ||
| defer r.mu.Unlock() | ||
|
|
||
| r.closed = true | ||
| } | ||
|
|
||
| // ForEach calls fn for each active stream. The collection is read-locked | ||
| // during iteration. | ||
| func (r *activeStreams) ForEach(fn func(*drpcstream.Stream)) { | ||
|
shubhamdhama marked this conversation as resolved.
|
||
| r.mu.RLock() | ||
| defer r.mu.RUnlock() | ||
|
|
||
| for _, s := range r.streams { | ||
| fn(s) | ||
| } | ||
| } | ||
|
|
||
| // Len returns the number of active streams. | ||
| func (r *activeStreams) Len() int { | ||
| r.mu.RLock() | ||
| defer r.mu.RUnlock() | ||
|
|
||
| return len(r.streams) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| // Copyright (C) 2026 Cockroach Labs. | ||
| // See LICENSE for copying information. | ||
|
|
||
| package drpcmanager | ||
|
|
||
| import ( | ||
| "context" | ||
| "testing" | ||
|
|
||
| "github.com/zeebo/assert" | ||
|
|
||
| "storj.io/drpc/drpcstream" | ||
| "storj.io/drpc/drpcwire" | ||
| ) | ||
|
|
||
| func testStream(id uint64) *drpcstream.Stream { | ||
| return drpcstream.New(context.Background(), id, &drpcwire.Writer{}) | ||
| } | ||
|
|
||
| func TestActiveStreams_AddAndGet(t *testing.T) { | ||
| streams := newActiveStreams() | ||
| s := testStream(1) | ||
|
|
||
| assert.NoError(t, streams.Add(1, s)) | ||
|
|
||
| got, ok := streams.Get(1) | ||
| assert.That(t, ok) | ||
| assert.Equal(t, got, s) | ||
| } | ||
|
|
||
| func TestActiveStreams_GetMissing(t *testing.T) { | ||
| streams := newActiveStreams() | ||
|
|
||
| got, ok := streams.Get(42) | ||
| assert.That(t, !ok) | ||
| assert.Nil(t, got) | ||
| } | ||
|
|
||
| func TestActiveStreams_Remove(t *testing.T) { | ||
| streams := newActiveStreams() | ||
| s := testStream(1) | ||
|
|
||
| assert.NoError(t, streams.Add(1, s)) | ||
| assert.Equal(t, streams.Len(), 1) | ||
|
|
||
| streams.Remove(1) | ||
|
|
||
| _, ok := streams.Get(1) | ||
| assert.That(t, !ok) | ||
| assert.Equal(t, streams.Len(), 0) | ||
| } | ||
|
|
||
| func TestActiveStreams_RemoveIdempotent(t *testing.T) { | ||
| streams := newActiveStreams() | ||
|
|
||
| // must not panic when removing a non-existent ID | ||
| streams.Remove(99) | ||
| } | ||
|
|
||
| func TestActiveStreams_DuplicateAdd(t *testing.T) { | ||
| streams := newActiveStreams() | ||
| s1 := testStream(1) | ||
| s2 := testStream(1) | ||
|
|
||
| assert.NoError(t, streams.Add(1, s1)) | ||
| assert.Error(t, streams.Add(1, s2)) | ||
|
|
||
| // original stream is still present | ||
| got, ok := streams.Get(1) | ||
| assert.That(t, ok) | ||
| assert.Equal(t, got, s1) | ||
| } | ||
|
|
||
| func TestActiveStreams_AddAfterClose(t *testing.T) { | ||
| streams := newActiveStreams() | ||
| streams.Close() | ||
|
|
||
| err := streams.Add(1, testStream(1)) | ||
| assert.Error(t, err) | ||
| } | ||
|
|
||
| func TestActiveStreams_RemoveAfterClose(t *testing.T) { | ||
| streams := newActiveStreams() | ||
| s := testStream(1) | ||
| assert.NoError(t, streams.Add(1, s)) | ||
|
|
||
| streams.Close() | ||
|
|
||
| // must not panic | ||
| streams.Remove(1) | ||
| } | ||
|
|
||
| func TestActiveStreams_Len(t *testing.T) { | ||
| streams := newActiveStreams() | ||
| assert.Equal(t, streams.Len(), 0) | ||
|
|
||
| assert.NoError(t, streams.Add(1, testStream(1))) | ||
| assert.Equal(t, streams.Len(), 1) | ||
|
|
||
| assert.NoError(t, streams.Add(2, testStream(2))) | ||
| assert.Equal(t, streams.Len(), 2) | ||
|
|
||
| streams.Remove(1) | ||
| assert.Equal(t, streams.Len(), 1) | ||
| } | ||
|
|
||
| func TestActiveStreams_ForEach(t *testing.T) { | ||
| streams := newActiveStreams() | ||
| s1 := testStream(1) | ||
| s2 := testStream(2) | ||
| s3 := testStream(3) | ||
|
|
||
| assert.NoError(t, streams.Add(1, s1)) | ||
| assert.NoError(t, streams.Add(2, s2)) | ||
| assert.NoError(t, streams.Add(3, s3)) | ||
|
|
||
| seen := make(map[uint64]*drpcstream.Stream) | ||
| streams.ForEach(func(s *drpcstream.Stream) { | ||
| seen[s.ID()] = s | ||
| }) | ||
|
|
||
| assert.Equal(t, len(seen), 3) | ||
| assert.Equal(t, seen[1], s1) | ||
| assert.Equal(t, seen[2], s2) | ||
| assert.Equal(t, seen[3], s3) | ||
| } | ||
|
|
||
| func TestActiveStreams_ForEach_Empty(t *testing.T) { | ||
| streams := newActiveStreams() | ||
|
|
||
| count := 0 | ||
| streams.ForEach(func(_ *drpcstream.Stream) { count++ }) | ||
| assert.Equal(t, count, 0) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ import ( | |
| "net" | ||
| "strings" | ||
| "sync" | ||
| "sync/atomic" | ||
| "syscall" | ||
| "time" | ||
|
|
||
|
|
@@ -81,10 +82,14 @@ type Manager struct { | |
|
|
||
| wg sync.WaitGroup // tracks active manageStream goroutines | ||
|
|
||
| sem drpcsignal.Chan // held by the active stream | ||
| sbuf streamBuffer // largest stream id created | ||
| // streams tracks active streams. Currently holds at most one active stream; | ||
| // a second may briefly coexist during stream handoff (old stream's Remove | ||
| // races with new stream's Add). | ||
| streams *activeStreams | ||
|
|
||
| pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream | ||
| sem drpcsignal.Chan // held by the active stream | ||
|
|
||
| pdone drpcsignal.Chan // signals when NewServerStream has added the new stream | ||
| invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream | ||
|
|
||
| // Below fields are owned by the manageReader goroutine, used in handleInvokeFrame. | ||
|
|
@@ -123,9 +128,6 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager { | |
| invokes: make(chan invokeInfo), | ||
| } | ||
|
|
||
| // initialize the stream buffer | ||
| m.sbuf.init() | ||
|
|
||
| // this semaphore controls the number of concurrent streams. it MUST be 1. | ||
| m.sem.Make(1) | ||
|
|
||
|
|
@@ -134,6 +136,7 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager { | |
| m.pdone.Make(1) | ||
|
|
||
| m.pa = drpcwire.NewPacketAssembler() | ||
| m.streams = newActiveStreams() | ||
|
|
||
| // set the internal stream options | ||
| drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr) | ||
|
|
@@ -186,7 +189,7 @@ func (m *Manager) acquireSemaphore(ctx context.Context) error { | |
| // longer make any reads or writes on the transport. It exits early if the | ||
| // context is canceled or the manager is terminated. | ||
| func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) { | ||
| prev := m.sbuf.Get() | ||
| prev := m.streams.GetLatest() | ||
| if prev == nil { | ||
| return nil | ||
| } | ||
|
|
@@ -217,7 +220,7 @@ func (m *Manager) terminate(err error) { | |
| if m.sigs.term.Set(err) { | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [pp] Let's depend on
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll do in follow-up. |
||
| m.log("TERM", func() string { return fmt.Sprint(err) }) | ||
| m.sigs.tport.Set(m.tr.Close()) | ||
| m.sbuf.Close() | ||
| m.streams.Close() | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -249,7 +252,7 @@ func (m *Manager) manageReader() { | |
| return | ||
| } | ||
|
|
||
| switch curr := m.sbuf.Get(); { | ||
| switch curr := m.streams.GetLatest(); { | ||
| // If the frame is for the current stream, deliver it. | ||
| case curr != nil && incomingFrame.ID.Stream == curr.ID(): | ||
| if err := curr.HandleFrame(incomingFrame); err != nil { | ||
|
|
@@ -272,14 +275,7 @@ func (m *Manager) manageReader() { | |
| } | ||
|
|
||
| default: | ||
| // A non-invoke frame arrived for a stream that doesn't exist yet | ||
| // (curr is nil or incomingFrame.ID.Stream > curr.ID). The first | ||
| // frame of a new stream must be KindInvoke or KindInvokeMetadata. | ||
| m.terminate(managerClosed.Wrap(drpc.ProtocolError.New( | ||
| "first frame of a new stream must be Invoke, got %v (ID:%v)", | ||
| incomingFrame.Kind, | ||
| incomingFrame.ID))) | ||
| return | ||
| m.log("DROP", incomingFrame.String) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -319,9 +315,9 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error { | |
| // Invoke packet completes the sequence. Send to NewServerStream. | ||
| select { | ||
| case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: m.metadata}: | ||
| // Wait for NewServerStream to finish stream creation (including | ||
| // sbuf.Set) before reading the next frame. This guarantees curr | ||
| // is set for subsequent non-invoke packets. | ||
| // Wait for NewServerStream to finish stream creation before reading the | ||
| // next frame. This guarantees curr is set for subsequent non-invoke | ||
| // packets. | ||
| m.pdone.Recv() | ||
|
|
||
| m.pa.Reset() | ||
|
|
@@ -346,10 +342,13 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin | |
|
|
||
| stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts) | ||
|
|
||
| if err := m.streams.Add(sid, stream); err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| m.wg.Add(1) | ||
| go m.manageStream(ctx, stream) | ||
|
|
||
| m.sbuf.Set(stream) | ||
| m.log("STREAM", stream.String) | ||
|
|
||
| return stream, nil | ||
|
|
@@ -359,6 +358,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin | |
| // is finished, canceling the stream if the context is canceled. | ||
| func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) { | ||
| defer m.wg.Done() | ||
| defer m.streams.Remove(stream.ID()) | ||
| select { | ||
| case <-m.sigs.term.Signal(): | ||
| err := m.sigs.term.Err() | ||
|
|
@@ -429,7 +429,7 @@ func (m *Manager) Closed() <-chan struct{} { | |
| // the return result is only valid until the next call to NewClientStream or | ||
| // NewServerStream. | ||
| func (m *Manager) Unblocked() <-chan struct{} { | ||
| if prev := m.sbuf.Get(); prev != nil { | ||
| if prev := m.streams.GetLatest(); prev != nil { | ||
| return prev.Context().Done() | ||
| } | ||
| return closedCh | ||
|
|
@@ -506,9 +506,8 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea | |
| } | ||
| } | ||
| stream, err := m.newStream(ctx, pkt.sid, drpc.StreamKindServer, rpc) | ||
| // Signal pdone only after stream registration so that | ||
| // manageReader sees the new stream via sbuf.Get() when it reads | ||
| // the next frame. | ||
| // Signal pdone only after adding the stream so that manageReader sees | ||
| // the new stream in activeStreams when it reads the next frame. | ||
| m.pdone.Send() | ||
| return stream, rpc, err | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(for future commits), if you make this an atomic boolean, you can do this check outside. You can also move this check outside once you rely on manager closed signal.