Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions drpcmanager/active_streams.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright (C) 2026 Cockroach Labs.
// See LICENSE for copying information.

package drpcmanager

import (
"sync"

"storj.io/drpc/drpcstream"
)

// activeStreams is a thread-safe map of stream IDs to stream objects.
// It is used by the Manager to track active streams for lifecycle management.
//
// The zero value is not usable; construct instances with newActiveStreams.
type activeStreams struct {
	mu      sync.RWMutex                  // guards streams and closed
	streams map[uint64]*drpcstream.Stream // active streams keyed by stream ID
	closed  bool                          // set by Close; causes subsequent Add calls to fail
}

// newActiveStreams constructs an empty, open activeStreams collection
// ready to accept Add calls.
func newActiveStreams() *activeStreams {
	as := &activeStreams{}
	as.streams = map[uint64]*drpcstream.Stream{}
	return as
}

// Add adds a stream. It returns an error if the collection is closed or if a
// stream with the same ID already exists.
func (r *activeStreams) Add(id uint64, stream *drpcstream.Stream) error {
if stream == nil {
return managerClosed.New("stream can't be nil")
}

r.mu.Lock()
defer r.mu.Unlock()

if r.closed {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(for future commits), if you make this an atomic boolean, you can do this check outside. You can also move this check outside once you rely on manager closed signal.

return managerClosed.New("add to closed collection")
}
if _, ok := r.streams[id]; ok {
return managerClosed.New("duplicate stream id")
}
r.streams[id] = stream
return nil
}

// Remove removes a stream. It is a no-op if the stream is not present or if
// the collection has been closed.
func (r *activeStreams) Remove(id uint64) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// The built-in delete is already a no-op for missing keys (and even for
	// a nil map), so no presence check is needed. The previous
	// `if r.streams != nil` guard was dead code: newActiveStreams always
	// allocates the map and nothing ever sets it back to nil.
	delete(r.streams, id)
}

// Get returns the stream for the given ID and whether it was found.
func (r *activeStreams) Get(id uint64) (*drpcstream.Stream, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	stream, found := r.streams[id]
	return stream, found
}

// GetLatest returns the stream with the highest ID, or nil if empty.
func (r *activeStreams) GetLatest() *drpcstream.Stream {
r.mu.RLock()
defer r.mu.RUnlock()

var latest *drpcstream.Stream
for _, s := range r.streams {
Comment thread
shubhamdhama marked this conversation as resolved.
if latest == nil || latest.ID() < s.ID() {
latest = s
}
}
return latest
}

// Close marks the collection as closed, preventing future Add calls.
// It does not cancel any streams.
func (r *activeStreams) Close() {
	r.mu.Lock()
	r.closed = true
	r.mu.Unlock()
}

// ForEach calls fn for each active stream. The collection is read-locked
// during iteration.
func (r *activeStreams) ForEach(fn func(*drpcstream.Stream)) {
Comment thread
shubhamdhama marked this conversation as resolved.
r.mu.RLock()
defer r.mu.RUnlock()

for _, s := range r.streams {
fn(s)
}
}

// Len returns the number of active streams.
func (r *activeStreams) Len() int {
	r.mu.RLock()
	n := len(r.streams)
	r.mu.RUnlock()
	return n
}
134 changes: 134 additions & 0 deletions drpcmanager/active_streams_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (C) 2026 Cockroach Labs.
// See LICENSE for copying information.

package drpcmanager

import (
"context"
"testing"

"github.com/zeebo/assert"

"storj.io/drpc/drpcstream"
"storj.io/drpc/drpcwire"
)

// testStream constructs a minimal stream with the given ID for use in tests.
func testStream(id uint64) *drpcstream.Stream {
	writer := &drpcwire.Writer{}
	return drpcstream.New(context.Background(), id, writer)
}

func TestActiveStreams_AddAndGet(t *testing.T) {
	reg := newActiveStreams()
	want := testStream(1)

	assert.NoError(t, reg.Add(1, want))

	got, found := reg.Get(1)
	assert.That(t, found)
	assert.Equal(t, got, want)
}

func TestActiveStreams_GetMissing(t *testing.T) {
	reg := newActiveStreams()

	stream, found := reg.Get(42)
	assert.That(t, !found)
	assert.Nil(t, stream)
}

func TestActiveStreams_Remove(t *testing.T) {
	reg := newActiveStreams()
	stream := testStream(1)

	assert.NoError(t, reg.Add(1, stream))
	assert.Equal(t, reg.Len(), 1)

	reg.Remove(1)

	_, found := reg.Get(1)
	assert.That(t, !found)
	assert.Equal(t, reg.Len(), 0)
}

func TestActiveStreams_RemoveIdempotent(t *testing.T) {
	reg := newActiveStreams()

	// must not panic when removing a non-existent ID
	reg.Remove(99)
}

func TestActiveStreams_DuplicateAdd(t *testing.T) {
	reg := newActiveStreams()
	first := testStream(1)
	second := testStream(1)

	assert.NoError(t, reg.Add(1, first))
	assert.Error(t, reg.Add(1, second))

	// original stream is still present
	got, found := reg.Get(1)
	assert.That(t, found)
	assert.Equal(t, got, first)
}

func TestActiveStreams_AddAfterClose(t *testing.T) {
	reg := newActiveStreams()
	reg.Close()

	assert.Error(t, reg.Add(1, testStream(1)))
}

func TestActiveStreams_RemoveAfterClose(t *testing.T) {
	reg := newActiveStreams()
	assert.NoError(t, reg.Add(1, testStream(1)))

	reg.Close()

	// must not panic
	reg.Remove(1)
}

func TestActiveStreams_Len(t *testing.T) {
	reg := newActiveStreams()
	assert.Equal(t, reg.Len(), 0)

	assert.NoError(t, reg.Add(1, testStream(1)))
	assert.Equal(t, reg.Len(), 1)

	assert.NoError(t, reg.Add(2, testStream(2)))
	assert.Equal(t, reg.Len(), 2)

	reg.Remove(1)
	assert.Equal(t, reg.Len(), 1)
}

func TestActiveStreams_ForEach(t *testing.T) {
	reg := newActiveStreams()
	expected := map[uint64]*drpcstream.Stream{
		1: testStream(1),
		2: testStream(2),
		3: testStream(3),
	}
	for id, stream := range expected {
		assert.NoError(t, reg.Add(id, stream))
	}

	seen := make(map[uint64]*drpcstream.Stream)
	reg.ForEach(func(s *drpcstream.Stream) {
		seen[s.ID()] = s
	})

	assert.Equal(t, len(seen), 3)
	for id, stream := range expected {
		assert.Equal(t, seen[id], stream)
	}
}

func TestActiveStreams_ForEach_Empty(t *testing.T) {
	reg := newActiveStreams()

	calls := 0
	reg.ForEach(func(_ *drpcstream.Stream) { calls++ })
	assert.Equal(t, calls, 0)
}
49 changes: 24 additions & 25 deletions drpcmanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -81,10 +82,14 @@ type Manager struct {

wg sync.WaitGroup // tracks active manageStream goroutines

sem drpcsignal.Chan // held by the active stream
sbuf streamBuffer // largest stream id created
// streams tracks active streams. Currently holds at most one active stream;
// a second may briefly coexist during stream handoff (old stream's Remove
// races with new stream's Add).
streams *activeStreams

pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream
sem drpcsignal.Chan // held by the active stream

pdone drpcsignal.Chan // signals when NewServerStream has added the new stream
invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream

// Below fields are owned by the manageReader goroutine, used in handleInvokeFrame.
Expand Down Expand Up @@ -123,9 +128,6 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
invokes: make(chan invokeInfo),
}

// initialize the stream buffer
m.sbuf.init()

// this semaphore controls the number of concurrent streams. it MUST be 1.
m.sem.Make(1)

Expand All @@ -134,6 +136,7 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
m.pdone.Make(1)

m.pa = drpcwire.NewPacketAssembler()
m.streams = newActiveStreams()

// set the internal stream options
drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr)
Expand Down Expand Up @@ -186,7 +189,7 @@ func (m *Manager) acquireSemaphore(ctx context.Context) error {
// longer make any reads or writes on the transport. It exits early if the
// context is canceled or the manager is terminated.
func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) {
prev := m.sbuf.Get()
prev := m.streams.GetLatest()
if prev == nil {
return nil
}
Expand Down Expand Up @@ -217,7 +220,7 @@ func (m *Manager) terminate(err error) {
if m.sigs.term.Set(err) {
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[pp] Let's depend on .term in the registry.Register. And close the streams in reg.Close and remove the .term dependency from manageStreams

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do in follow-up.

m.log("TERM", func() string { return fmt.Sprint(err) })
m.sigs.tport.Set(m.tr.Close())
m.sbuf.Close()
m.streams.Close()
}
}

Expand Down Expand Up @@ -249,7 +252,7 @@ func (m *Manager) manageReader() {
return
}

switch curr := m.sbuf.Get(); {
switch curr := m.streams.GetLatest(); {
// If the frame is for the current stream, deliver it.
case curr != nil && incomingFrame.ID.Stream == curr.ID():
if err := curr.HandleFrame(incomingFrame); err != nil {
Expand All @@ -272,14 +275,7 @@ func (m *Manager) manageReader() {
}

default:
// A non-invoke frame arrived for a stream that doesn't exist yet
// (curr is nil or incomingFrame.ID.Stream > curr.ID). The first
// frame of a new stream must be KindInvoke or KindInvokeMetadata.
m.terminate(managerClosed.Wrap(drpc.ProtocolError.New(
"first frame of a new stream must be Invoke, got %v (ID:%v)",
incomingFrame.Kind,
incomingFrame.ID)))
return
m.log("DROP", incomingFrame.String)
}
}
}
Expand Down Expand Up @@ -319,9 +315,9 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error {
// Invoke packet completes the sequence. Send to NewServerStream.
select {
case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: m.metadata}:
// Wait for NewServerStream to finish stream creation (including
// sbuf.Set) before reading the next frame. This guarantees curr
// is set for subsequent non-invoke packets.
// Wait for NewServerStream to finish stream creation before reading the
// next frame. This guarantees curr is set for subsequent non-invoke
// packets.
m.pdone.Recv()

m.pa.Reset()
Expand All @@ -346,10 +342,13 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin

stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts)

if err := m.streams.Add(sid, stream); err != nil {
return nil, err
}

m.wg.Add(1)
go m.manageStream(ctx, stream)

m.sbuf.Set(stream)
m.log("STREAM", stream.String)

return stream, nil
Expand All @@ -359,6 +358,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin
// is finished, canceling the stream if the context is canceled.
func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) {
defer m.wg.Done()
defer m.streams.Remove(stream.ID())
select {
case <-m.sigs.term.Signal():
err := m.sigs.term.Err()
Expand Down Expand Up @@ -429,7 +429,7 @@ func (m *Manager) Closed() <-chan struct{} {
// the return result is only valid until the next call to NewClientStream or
// NewServerStream.
func (m *Manager) Unblocked() <-chan struct{} {
if prev := m.sbuf.Get(); prev != nil {
if prev := m.streams.GetLatest(); prev != nil {
return prev.Context().Done()
}
return closedCh
Expand Down Expand Up @@ -506,9 +506,8 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea
}
}
stream, err := m.newStream(ctx, pkt.sid, drpc.StreamKindServer, rpc)
// Signal pdone only after stream registration so that
// manageReader sees the new stream via sbuf.Get() when it reads
// the next frame.
// Signal pdone only after adding the stream so that manageReader sees
// the new stream in activeStreams when it reads the next frame.
m.pdone.Send()
return stream, rpc, err
}
Expand Down
Loading