Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 144 additions & 4 deletions pkg/fetch/celestia_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,35 @@ type CelestiaNodeFetcher struct {
blob blobAPI
headerCloser jsonrpc.ClientCloser
blobCloser jsonrpc.ClientCloser
addr string // original address for creating WS subscription clients
authHeader http.Header
log zerolog.Logger
mu sync.Mutex
closed bool
subCloser jsonrpc.ClientCloser // WS client for subscriptions, if any
}

// Tuning knobs for upstream Celestia node RPC calls and the WS upgrade.
const (
	defaultRPCTimeout = 8 * time.Second // per-call deadline; presumably applied in callRawWithRetry — confirm there
	defaultRPCMaxRetries = 2 // retry budget for callRawWithRetry; NOTE(review): confirm whether this counts retries or total attempts
	defaultRPCRetryDelay = 100 * time.Millisecond // pause between retry attempts
	wsDialTimeout = 10 * time.Second // bound on the WS dial+subscribe attempt in subscribeViaWS
)

// NewCelestiaNodeFetcher connects to a Celestia node at the given WebSocket address.
// NewCelestiaNodeFetcher connects to a Celestia node at the given address.
// Regular RPC calls use the provided URL scheme (typically HTTP).
// Subscriptions automatically upgrade to WebSocket when needed.
func NewCelestiaNodeFetcher(ctx context.Context, addr, token string, log zerolog.Logger) (*CelestiaNodeFetcher, error) {
headers := http.Header{}
if token != "" {
headers.Set("Authorization", "Bearer "+token)
}

f := &CelestiaNodeFetcher{log: log}
f := &CelestiaNodeFetcher{
addr: addr,
authHeader: headers,
log: log,
}

var err error
f.headerCloser, err = jsonrpc.NewClient(ctx, addr, "header", &f.header, headers)
Expand All @@ -76,6 +86,19 @@ func NewCelestiaNodeFetcher(ctx context.Context, addr, token string, log zerolog
return f, nil
}

// httpToWS rewrites an HTTP(S) address into its WebSocket equivalent:
// http:// becomes ws:// and https:// becomes wss://. Any other scheme,
// including an address already using ws:// or wss://, is returned unchanged.
func httpToWS(addr string) string {
	const plain, secure = "http://", "https://"
	if strings.HasPrefix(addr, plain) {
		return "ws://" + addr[len(plain):]
	}
	if strings.HasPrefix(addr, secure) {
		return "wss://" + addr[len(secure):]
	}
	return addr
}

func (f *CelestiaNodeFetcher) GetHeader(ctx context.Context, height uint64) (*types.Header, error) {
raw, err := f.callRawWithRetry(ctx, "header.GetByHeight", func(callCtx context.Context) (json.RawMessage, error) {
return f.header.GetByHeight(callCtx, height)
Expand Down Expand Up @@ -147,11 +170,82 @@ func (f *CelestiaNodeFetcher) callRawWithRetry(ctx context.Context, op string, f
}

// SubscribeHeaders returns a channel of new network headers. It prefers a
// true header.Subscribe stream — first on the existing client, then via a
// dedicated WebSocket upgrade — and degrades to 1s polling when neither
// works, so the returned channel and nil error are always usable.
func (f *CelestiaNodeFetcher) SubscribeHeaders(ctx context.Context) (<-chan *types.Header, error) {
	// Attempt 1: subscribe on the existing client (succeeds if it is WS).
	if rawCh, subErr := f.header.Subscribe(ctx); subErr == nil {
		return f.forwardHeaders(ctx, rawCh), nil
	}
	// Attempt 2: the client is likely HTTP — upgrade to WS for subscriptions.
	rawCh, err := f.subscribeViaWS(ctx)
	if err == nil {
		return f.forwardHeaders(ctx, rawCh), nil
	}
	// Neither worked — fall back to polling.
	f.log.Warn().Err(err).Msg("header.Subscribe not available, falling back to polling")
	return f.pollHeaders(ctx), nil
}

// wsSubscribeResult holds the outcome of a WS subscribe attempt.
// On success ch and closer are set; on failure only err is set, so a
// receiver can detect cleanup work by checking closer != nil.
type wsSubscribeResult struct {
	ch <-chan json.RawMessage // raw header stream from header.Subscribe
	closer jsonrpc.ClientCloser // shuts down the dedicated WS client
	err error // dial or subscribe failure, if any
}

// subscribeViaWS creates a separate WebSocket client for header subscriptions.
// This handles the case where the main client uses HTTP (no channel support).
// The connection attempt is bounded by wsDialTimeout; if the node doesn't
// support WebSocket the goroutine is abandoned (cleaned up when ctx ends).
func (f *CelestiaNodeFetcher) subscribeViaWS(ctx context.Context) (<-chan json.RawMessage, error) {
wsAddr := httpToWS(f.addr)
if wsAddr == f.addr {
return nil, fmt.Errorf("address %q is not HTTP; cannot upgrade to WebSocket", f.addr)
}

f.log.Info().Str("ws_addr", wsAddr).Msg("upgrading to WebSocket for header subscription")

// Run the WS dial + subscribe in a goroutine so we can timeout if the
// node doesn't accept WebSocket connections. The parent ctx is passed to
// NewClient because it controls the WS connection lifetime (not just dial).
done := make(chan wsSubscribeResult, 1)
go func() {
var subAPI headerAPI
closer, err := jsonrpc.NewClient(ctx, wsAddr, "header", &subAPI, f.authHeader)
if err != nil {
done <- wsSubscribeResult{err: fmt.Errorf("connect WS header client: %w", err)}
return
}
ch, err := subAPI.Subscribe(ctx)
if err != nil {
closer()
done <- wsSubscribeResult{err: fmt.Errorf("header.Subscribe via WS: %w", err)}
return
}
done <- wsSubscribeResult{ch: ch, closer: closer}
}()

select {
case r := <-done:
if r.err != nil {
return nil, r.err
}
f.mu.Lock()
old := f.subCloser
f.subCloser = r.closer
f.mu.Unlock()
if old != nil {
old()
}
return r.ch, nil
case <-time.After(wsDialTimeout):
return nil, fmt.Errorf("WS connection to %s timed out after %s", wsAddr, wsDialTimeout)
case <-ctx.Done():
return nil, ctx.Err()
}
Comment on lines +210 to 244
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

WS connection leaks when the dial-timeout fires but the goroutine subsequently succeeds.

done is a chan wsSubscribeResult of capacity 1. When time.After(wsDialTimeout) fires and subscribeViaWS returns an error (falling back to polling), the goroutine is still alive with the parent ctx. If it later manages to dial and subscribe it writes {ch: ch, closer: closer} into the now-unread buffer. Nobody drains done after the timeout branch exits, so:

  1. closer is never called — the WS client stays open until ctx ends.
  2. The go-jsonrpc internal goroutine that feeds ch runs indefinitely with no consumer.
🛡️ Proposed fix: drain and close in a background goroutine on timeout
 	select {
 	case r := <-done:
 		if r.err != nil {
 			return nil, r.err
 		}
 		f.mu.Lock()
 		old := f.subCloser
 		f.subCloser = r.closer
 		f.mu.Unlock()
 		if old != nil {
 			old()
 		}
 		return r.ch, nil
 	case <-time.After(wsDialTimeout):
+		// Drain the channel in the background and clean up if the goroutine
+		// eventually succeeds, to avoid leaking the WS connection/subscription.
+		go func() {
+			if r := <-done; r.closer != nil {
+				r.closer()
+			}
+		}()
 		return nil, fmt.Errorf("WS connection to %s timed out after %s", wsAddr, wsDialTimeout)
 	case <-ctx.Done():
 		return nil, ctx.Err()
 	}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@pkg/fetch/celestia_node.go` around lines 210 - 244, The goroutine started to
dial/subscribe can succeed after the select times out and will try to send a
wsSubscribeResult into the buffered done channel, leaking the WS client and
internal goroutines; fix by spawning a background goroutine in the timeout
branch that drains done and invokes the returned closer if any (and otherwise
discards errors) so resources are cleaned up. Specifically, in the timeout case
where the select returns the time.After branch, start a goroutine like: go
func() { r := <-done; if r.closer != nil { r.closer() } }() so that results from
jsonrpc.NewClient / subAPI.Subscribe are consumed and closer is called; keep
existing logic that sets f.subCloser and calls old() in the successful select
case unchanged.

}

// forwardHeaders maps raw JSON headers from a subscription channel to typed headers.
func (f *CelestiaNodeFetcher) forwardHeaders(ctx context.Context, rawCh <-chan json.RawMessage) <-chan *types.Header {
out := make(chan *types.Header, 64)
go func() {
defer close(out)
Expand All @@ -176,8 +270,51 @@ func (f *CelestiaNodeFetcher) SubscribeHeaders(ctx context.Context) (<-chan *typ
}
}
}()
return out
}

// pollHeaders polls GetNetworkHead at 1s intervals, emitting new headers when
// the height advances. Used as a fallback when header.Subscribe is unavailable.
//
// NOTE: only the current chain tip is emitted; intermediate heights produced
// between ticks are skipped. The sync coordinator handles this via gap detection
// and re-backfill, so no data is lost — but this path is higher latency than
// a true subscription. The returned channel is closed when ctx is cancelled.
func (f *CelestiaNodeFetcher) pollHeaders(ctx context.Context) <-chan *types.Header {
	out := make(chan *types.Header, 64)
	go func() {
		defer close(out)

		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()

		// Height of the last header emitted; suppresses duplicates when the
		// tip has not advanced between ticks.
		var lastHeight uint64

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				hdr, err := f.GetNetworkHead(ctx)
				if err != nil {
					if ctx.Err() != nil {
						return // shutting down — not a real poll failure
					}
					f.log.Warn().Err(err).Msg("poll network head failed")
					continue
				}
				if hdr.Height <= lastHeight {
					continue // tip unchanged (or node briefly behind); nothing new
				}
				lastHeight = hdr.Height
				// Guard the send with ctx so a full out channel cannot block
				// this goroutine past cancellation.
				select {
				case out <- hdr:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
	return out
}

// GetProof forwards a blob proof request to the upstream Celestia node.
Expand Down Expand Up @@ -210,6 +347,9 @@ func (f *CelestiaNodeFetcher) Close() error {
return nil
}
f.closed = true
if f.subCloser != nil {
f.subCloser()
}
f.headerCloser()
f.blobCloser()
return nil
Expand Down
14 changes: 14 additions & 0 deletions pkg/sync/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"time"

"github.com/rs/zerolog"

Expand All @@ -12,6 +13,8 @@ import (
"github.com/evstack/apex/pkg/types"
)

const streamingLogInterval = 30 * time.Second

// SubscriptionManager processes new headers from a live subscription.
type SubscriptionManager struct {
store store.Store
Expand Down Expand Up @@ -45,10 +48,20 @@ func (sm *SubscriptionManager) Run(ctx context.Context) error {
networkHeight = ss.NetworkHeight
}

ticker := time.NewTicker(streamingLogInterval)
defer ticker.Stop()
var processed uint64

for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
sm.log.Info().
Uint64("height", lastHeight).
Uint64("blocks", processed).
Msg("streaming progress")
processed = 0
case hdr, ok := <-ch:
if !ok {
// Channel closed (disconnect or ctx cancelled).
Expand All @@ -72,6 +85,7 @@ func (sm *SubscriptionManager) Run(ctx context.Context) error {
}

lastHeight = hdr.Height
processed++
}
}
}
Expand Down