diff --git a/go/stream.go b/go/stream.go index cf948ae..9f38f13 100644 --- a/go/stream.go +++ b/go/stream.go @@ -130,6 +130,15 @@ func (c *client) newStream(ctx context.Context, httpClient *http.Client, feedIDs if err != nil { c.config.logInfo("client: failed to connect to origin %s: %s", origins[x], err) errs = append(errs, fmt.Errorf("origin %s: %w", origins[x], err)) + // Retry connecting to the origin in the background + go func() { + conn, err := s.newWSconnWithRetry(origins[x]) + if err != nil { + return + } + go s.monitorConn(conn) + s.conns = append(s.conns, conn) + }() continue } go s.monitorConn(conn) @@ -138,7 +147,10 @@ func (c *client) newStream(ctx context.Context, httpClient *http.Client, feedIDs // Only fail if we couldn't connect to ANY origins if len(s.conns) == 0 { - return nil, fmt.Errorf("failed to connect to any origins in HA mode: %v", errs) + err = fmt.Errorf("failed to connect to any origins in HA mode: %v", errs) + s.closeError.CompareAndSwap(nil, err) + s.Close() + return nil, err } c.config.logInfo("client: connected to %d out of %d origins in HA mode", len(s.conns), len(origins)) } else { @@ -237,52 +249,52 @@ func (s *stream) monitorConn(conn *wsConn) { // ensure the current connection is closed _ = conn.close() - // reconnect loop - // will try to reconnect until client is closed or - // we have no active connections and have exceeded maxWSReconnectAttempts - var attempts int - for { - var re *wsConn - var err error - - if s.closed.Load() { - return - } + re, err := s.newWSconnWithRetry(conn.origin) + if err != nil { + s.closeError.CompareAndSwap(nil, fmt.Errorf("stream has no active connections, last error: %w", err)) + s.Close() + return + } + conn.replace(re.conn) + s.config.logInfo( + "client: stream websocket %s: reconnected", + conn.origin, + ) + } +} - // fail the stream if we are over the maxWSReconnectAttempts - // and there are no other active connection - if attempts >= s.config.WsMaxReconnect && s.stats.activeConnections.Load() == 0 { - s.closeError.CompareAndSwap(nil, fmt.Errorf("stream has no active connections, last error: %w", err)) - s.Close() - return - } - attempts++ +func (s *stream) newWSconnWithRetry(origin string) (conn *wsConn, err error) { + // reconnect loop + // will try to reconnect until client is closed or + // we have no active connections and have exceeded maxWSReconnectAttempts + var attempts int + for { + if s.closed.Load() || s.streamCtx.Err() != nil { + return nil, fmt.Errorf("Retry cancelled, stream is closed") + } - ctx, cancel = context.WithTimeout(context.Background(), defaultWSConnectTimeout) - re, err = s.newWSconn(ctx, conn.origin) - cancel() + // fail the stream if we are over the maxWSReconnectAttempts + // and there are no other active connection + if attempts >= s.config.WsMaxReconnect && s.stats.activeConnections.Load() == 0 { + return nil, err + } + attempts++ - if err != nil { - interval := time.Millisecond * time.Duration( - rand.Intn(maxWSReconnectIntervalMIllis-minWSReconnectIntervalMillis)+minWSReconnectIntervalMillis) //nolint:gosec - s.config.logInfo( - "client: stream websocket %s: error reconnecting: %s, backing off: %s", - conn.origin, err, interval.String(), - ) - time.Sleep(interval) - continue - } + ctx, cancel := context.WithTimeout(context.Background(), defaultWSConnectTimeout) + conn, err = s.newWSconn(ctx, origin) + cancel() - conn.replace(re.conn) - if s.connStatusCallback != nil { - go s.connStatusCallback(true, conn.host, conn.origin) - } + if err != nil { + interval := time.Millisecond * time.Duration( + rand.Intn(maxWSReconnectIntervalMIllis-minWSReconnectIntervalMillis)+minWSReconnectIntervalMillis) //nolint:gosec s.config.logInfo( - "client: stream websocket %s: reconnected", - conn.origin, + "client: stream websocket %s: error reconnecting: %s, backing off: %s", + origin, err, interval.String(), ) - break + time.Sleep(interval) + continue } + return conn, nil } } diff --git a/go/stream_test.go b/go/stream_test.go index f9be3c0..2406ab0 100644 --- a/go/stream_test.go +++ b/go/stream_test.go @@ -801,3 +801,96 @@ func TestClient_StreamHA_OneOriginDown(t *testing.T) { } } + +// Tests that when in HA mode both origins are up after a recovery period even if one origin is down on initial connection +func TestClient_StreamHA_OneOriginDownRecovery(t *testing.T) { + connectAttempts := &atomic.Uint64{} + reconnectAttemptsBeforeRecovery := uint64(4) + + ms := newMockServer(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodHead { + w.Header().Add(cllAvailOriginsHeader, "{001,002}") + w.WriteHeader(200) + return + } + + if r.URL.Path != apiV1WS { + t.Errorf("expected path %s, got %s", apiV1WS, r.URL.Path) + } + + origin := r.Header.Get(cllOriginHeader) + connectAttempts.Add(1) + + // Simulate origin 002 being down for the first reconnectAttemptsBeforeRecovery attempts + // Add one to count for 001 connection + if origin == "002" && connectAttempts.Load() <= reconnectAttemptsBeforeRecovery+1 { + w.WriteHeader(http.StatusGatewayTimeout) + return + } + + conn, err := websocket.Accept( + w, r, &websocket.AcceptOptions{CompressionMode: websocket.CompressionContextTakeover}, + ) + + if err != nil { + t.Fatalf("error accepting connection: %s", err) + } + defer func() { _ = conn.CloseNow() }() + + // Keep the connection alive for testing + for { + _, _, err := conn.Read(context.Background()) + if err != nil { + break + } + } + }) + defer ms.Close() + + streamsClient, err := ms.Client() + if err != nil { + t.Fatalf("error creating client %s", err) + } + + cc := streamsClient.(*client) + cc.config.Logger = LogPrintf + cc.config.LogDebug = true + cc.config.WsHA = true + + sub, err := streamsClient.StreamWithStatusCallback(context.Background(), []feed.ID{feed1, feed2}, func(connected bool, host string, origin string) { + t.Logf("status callback: connected=%v, host=%s, origin=%s", connected, host, origin) + }) + if err != nil { + t.Fatalf("error subscribing %s", err) + } + defer sub.Close() + + for connectAttempts.Load() != 2 { + time.Sleep(time.Millisecond) + } + + time.Sleep(time.Millisecond * 5) + stats := sub.Stats() + if stats.ActiveConnections != 1 { + t.Errorf("expected 1 active connection before recovery, got %d", stats.ActiveConnections) + } + + if stats.ConfiguredConnections != 2 { + t.Errorf("expected 2 configured connections before recovery, got %d", stats.ConfiguredConnections) + } + + // Add two to count one for 001 connection and one for 002 connection + for connectAttempts.Load() != reconnectAttemptsBeforeRecovery+2 { + time.Sleep(time.Millisecond) + } + + time.Sleep(time.Millisecond * 5) + stats = sub.Stats() + if stats.ActiveConnections != 2 { + t.Errorf("expected 2 active connection after recovery, got %d", stats.ActiveConnections) + } + + if stats.ConfiguredConnections != 2 { + t.Errorf("expected 2 configured connections after recovery, got %d", stats.ConfiguredConnections) + } +}