Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 52 additions & 40 deletions go/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ func (c *client) newStream(ctx context.Context, httpClient *http.Client, feedIDs
if err != nil {
c.config.logInfo("client: failed to connect to origin %s: %s", origins[x], err)
errs = append(errs, fmt.Errorf("origin %s: %w", origins[x], err))
// Retry connecting to the origin in the background
Copy link
Collaborator

@akuzni2 akuzni2 Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved so bugfix could be published - but if time permits I'd like to see if we could simplify this to avoid having 2 initialization points for conn create, monitoring, tracking. Something like this might work

var wg sync.WaitGroup
errs := make([]error, len(origins))

for x := 0; x < len(origins); x++ {
    wg.Add(1)
    go func(i int) {
        defer wg.Done()
        conn, err := s.newWSconnWithRetry(origins[i])
        if err != nil {
            errs[i] = fmt.Errorf("origin %s: %w", origins[i], err)
            return
        }
        s.conns = append(s.conns, conn) // Note: may need a mutex
        go s.monitorConn(conn)
    }(x)
}
wg.Wait()

// Only fail if we couldn't connect to ANY origins
if len(s.conns) == 0 {
    return nil, fmt.Errorf("failed to connect to any origins in HA mode: %v", errs)
}
// ...

I believe if neither origin is reachable then this logic will make sure that newWSconnWithRetry returns with an error and we can fail newStream

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I do agree I think just using the new connection method with retries would simplify things. I decided to keep both so that the newStream fails quickly if both origins aren't reachable to provide immediate feedback rather than going through all the retries ( 5-30 seconds).

I'd like to get this fix out though so I'll go ahead and merge the PR. I'm happy to discuss options here and make some further improvements though.

go func() {
conn, err := s.newWSconnWithRetry(origins[x])
if err != nil {
return
}
go s.monitorConn(conn)
s.conns = append(s.conns, conn)
}()
continue
}
go s.monitorConn(conn)
Expand All @@ -138,7 +147,10 @@ func (c *client) newStream(ctx context.Context, httpClient *http.Client, feedIDs

// Only fail if we couldn't connect to ANY origins
if len(s.conns) == 0 {
return nil, fmt.Errorf("failed to connect to any origins in HA mode: %v", errs)
err = fmt.Errorf("failed to connect to any origins in HA mode: %v", errs)
s.closeError.CompareAndSwap(nil, err)
s.Close()
return nil, err
}
c.config.logInfo("client: connected to %d out of %d origins in HA mode", len(s.conns), len(origins))
} else {
Expand Down Expand Up @@ -237,52 +249,52 @@ func (s *stream) monitorConn(conn *wsConn) {
// ensure the current connection is closed
_ = conn.close()

// reconnect loop
// will try to reconnect until client is closed or
// we have no active connections and have exceeded maxWSReconnectAttempts
var attempts int
for {
var re *wsConn
var err error

if s.closed.Load() {
return
}
re, err := s.newWSconnWithRetry(conn.origin)
if err != nil {
s.closeError.CompareAndSwap(nil, fmt.Errorf("stream has no active connections, last error: %w", err))
s.Close()
return
}
conn.replace(re.conn)
s.config.logInfo(
"client: stream websocket %s: reconnected",
conn.origin,
)
}
}

// fail the stream if we are over the maxWSReconnectAttempts
// and there are no other active connection
if attempts >= s.config.WsMaxReconnect && s.stats.activeConnections.Load() == 0 {
s.closeError.CompareAndSwap(nil, fmt.Errorf("stream has no active connections, last error: %w", err))
s.Close()
return
}
attempts++
func (s *stream) newWSconnWithRetry(origin string) (conn *wsConn, err error) {
// reconnect loop
// will try to reconnect until client is closed or
// we have no active connections and have exceeded maxWSReconnectAttempts
var attempts int
for {
if s.closed.Load() || s.streamCtx.Err() != nil {
return nil, fmt.Errorf("Retry cancelled, stream is closed")
}

ctx, cancel = context.WithTimeout(context.Background(), defaultWSConnectTimeout)
re, err = s.newWSconn(ctx, conn.origin)
cancel()
// fail the stream if we are over the maxWSReconnectAttempts
// and there are no other active connection
if attempts >= s.config.WsMaxReconnect && s.stats.activeConnections.Load() == 0 {
return nil, err
}
attempts++

if err != nil {
interval := time.Millisecond * time.Duration(
rand.Intn(maxWSReconnectIntervalMIllis-minWSReconnectIntervalMillis)+minWSReconnectIntervalMillis) //nolint:gosec
s.config.logInfo(
"client: stream websocket %s: error reconnecting: %s, backing off: %s",
conn.origin, err, interval.String(),
)
time.Sleep(interval)
continue
}
ctx, cancel := context.WithTimeout(context.Background(), defaultWSConnectTimeout)
conn, err = s.newWSconn(ctx, origin)
cancel()

conn.replace(re.conn)
if s.connStatusCallback != nil {
go s.connStatusCallback(true, conn.host, conn.origin)
}
if err != nil {
interval := time.Millisecond * time.Duration(
rand.Intn(maxWSReconnectIntervalMIllis-minWSReconnectIntervalMillis)+minWSReconnectIntervalMillis) //nolint:gosec
s.config.logInfo(
"client: stream websocket %s: reconnected",
conn.origin,
"client: stream websocket %s: error reconnecting: %s, backing off: %s",
origin, err, interval.String(),
)
break
time.Sleep(interval)
continue
}
return conn, nil
}
}

Expand Down
93 changes: 93 additions & 0 deletions go/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,96 @@ func TestClient_StreamHA_OneOriginDown(t *testing.T) {
}

}

// Tests that when in HA mode both origins are up after a recovery period even if one origin is down on initial connection
func TestClient_StreamHA_OneOriginDownRecovery(t *testing.T) {
connectAttempts := &atomic.Uint64{}
reconnectAttemptsBeforeRecovery := uint64(4)

ms := newMockServer(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodHead {
w.Header().Add(cllAvailOriginsHeader, "{001,002}")
w.WriteHeader(200)
return
}

if r.URL.Path != apiV1WS {
t.Errorf("expected path %s, got %s", apiV1WS, r.URL.Path)
}

origin := r.Header.Get(cllOriginHeader)
connectAttempts.Add(1)

// Simulate origin 002 being down for the first reconnectAttemptsBeforeRecovery attempts
// Add one to count for 001 connection
if origin == "002" && connectAttempts.Load() <= reconnectAttemptsBeforeRecovery+1 {
w.WriteHeader(http.StatusGatewayTimeout)
return
}

conn, err := websocket.Accept(
w, r, &websocket.AcceptOptions{CompressionMode: websocket.CompressionContextTakeover},
)

if err != nil {
t.Fatalf("error accepting connection: %s", err)
}
defer func() { _ = conn.CloseNow() }()

// Keep the connection alive for testing
for {
_, _, err := conn.Read(context.Background())
if err != nil {
break
}
}
})
defer ms.Close()

streamsClient, err := ms.Client()
if err != nil {
t.Fatalf("error creating client %s", err)
}

cc := streamsClient.(*client)
cc.config.Logger = LogPrintf
cc.config.LogDebug = true
cc.config.WsHA = true

sub, err := streamsClient.StreamWithStatusCallback(context.Background(), []feed.ID{feed1, feed2}, func(connected bool, host string, origin string) {
t.Logf("status callback: connected=%v, host=%s, origin=%s", connected, host, origin)
})
if err != nil {
t.Fatalf("error subscribing %s", err)
}
defer sub.Close()

for connectAttempts.Load() != 2 {
time.Sleep(time.Millisecond)
}

time.Sleep(time.Millisecond * 5)
stats := sub.Stats()
if stats.ActiveConnections != 1 {
t.Errorf("expected 1 active connection before recovery, got %d", stats.ActiveConnections)
}

if stats.ConfiguredConnections != 2 {
t.Errorf("expected 2 configured connections before recovery, got %d", stats.ConfiguredConnections)
}

// Add two to count one for 001 connection and one for 002 connection
for connectAttempts.Load() != reconnectAttemptsBeforeRecovery+2 {
time.Sleep(time.Millisecond)
}

time.Sleep(time.Millisecond * 5)
stats = sub.Stats()
if stats.ActiveConnections != 2 {
t.Errorf("expected 2 active connection after recovery, got %d", stats.ActiveConnections)
}

if stats.ConfiguredConnections != 2 {
t.Errorf("expected 2 configured connections after recovery, got %d", stats.ConfiguredConnections)
}
}