diff --git a/packages/api/internal/clusters/cluster.go b/packages/api/internal/clusters/cluster.go
index 6dae893cdf..7bd5516ac2 100644
--- a/packages/api/internal/clusters/cluster.go
+++ b/packages/api/internal/clusters/cluster.go
@@ -263,3 +263,10 @@ func (c *Cluster) GetOrchestrators() []*Instance {
 func (c *Cluster) GetResources() ClusterResource {
 	return c.resources
 }
+
+// SyncInstances performs an immediate synchronization of cluster instances from
+// the service discovery source. It is called on-demand when a node lookup fails,
+// to handle newly joined orchestrators that may not yet be in the in-memory pool.
+func (c *Cluster) SyncInstances(ctx context.Context) error {
+	return c.synchronization.Sync(ctx)
+}
diff --git a/packages/api/internal/orchestrator/client.go b/packages/api/internal/orchestrator/client.go
index 9690d6c25f..918680c77e 100644
--- a/packages/api/internal/orchestrator/client.go
+++ b/packages/api/internal/orchestrator/client.go
@@ -3,6 +3,7 @@ package orchestrator
 import (
 	"context"
 	"fmt"
+	"sync"
 	"time"
 
 	"github.com/google/uuid"
@@ -21,28 +22,60 @@ func (o *Orchestrator) connectToNode(ctx context.Context, discovered nodemanager
 	ctx, childSpan := tracer.Start(ctx, "connect-to-node")
 	defer childSpan.End()
 
-	orchestratorNode, err := nodemanager.New(ctx, o.tel.TracerProvider, o.tel.MeterProvider, discovered)
-	if err != nil {
-		return err
-	}
+	_, err, _ := o.connectGroup.Do(discovered.NomadNodeShortID, func() (any, error) {
+		// Re-check inside the singleflight to prevent race issues due to overwriting existing nodes in the map
+		if o.GetNodeByNomadShortID(discovered.NomadNodeShortID) != nil {
+			return nil, nil
+		}
 
-	// Update host metrics from service info
-	o.registerNode(orchestratorNode)
+		connectCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), nodeConnectTimeout)
+		defer cancel()
 
-	return nil
+		orchestratorNode, err := nodemanager.New(connectCtx, o.tel.TracerProvider, o.tel.MeterProvider, discovered)
+		if err != nil {
+			return nil, err
+		}
+
+		o.registerNode(orchestratorNode)
+
+		return nil, nil
+	})
+
+	return err
 }
 
 func (o *Orchestrator) connectToClusterNode(ctx context.Context, cluster *clusters.Cluster, i *clusters.Instance) {
-	orchestratorNode, err := nodemanager.NewClusterNode(ctx, i.GetClient(), cluster.ID, cluster.SandboxDomain, i)
-	if err != nil {
-		logger.L().Error(ctx, "Failed to create node", zap.Error(err))
+	ctx, span := tracer.Start(ctx, "connect-to-cluster-node")
+	defer span.End()
 
-		return
-	}
+	// connectGroup is keyed by scopedNodeID so that concurrent callers targeting
+	// the same cluster instance share a single dial attempt.
+	scopedKey := o.scopedNodeID(cluster.ID, i.NodeID)
+
+	o.connectGroup.Do(scopedKey, func() (any, error) { //nolint:errcheck
+		// Re-check inside the singleflight for the same reason as connectToNode.
+		if o.GetNode(cluster.ID, i.NodeID) != nil {
+			return nil, nil
+		}
+
+		connectCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), nodeConnectTimeout)
+		defer cancel()
+
+		orchestratorNode, err := nodemanager.NewClusterNode(connectCtx, i.GetClient(), cluster.ID, cluster.SandboxDomain, i)
+		if err != nil {
+			logger.L().Error(ctx, "Failed to create node", zap.Error(err))
 
-	o.registerNode(orchestratorNode)
+			return nil, nil
+		}
+
+		o.registerNode(orchestratorNode)
+
+		return nil, nil
+	})
 }
 
+// registerNode adds the given node to the in-memory map of nodes
+// It has to be called only once per node
 func (o *Orchestrator) registerNode(node *nodemanager.Node) {
 	scopedKey := o.scopedNodeID(node.ClusterID, node.ID)
 	o.nodes.Insert(scopedKey, node)
@@ -94,6 +127,114 @@ func (o *Orchestrator) GetNode(clusterID uuid.UUID, nodeID string) *nodemanager.
 	return n
 }
 
+// getOrConnectNode returns a node from the in-memory cache. When the node is absent it
+// performs a targeted on-demand discovery and connection attempt, handling the race
+// condition where a new orchestrator joined the cluster after this API instance's last
+// sync cycle but another API instance already routed a sandbox there.
+//
+// There are two distinct gaps that must be covered:
+//   - Gap 1 (0–5 s for clusters, 0–20 s for Nomad): the node exists in the upstream
+//     source (Nomad / remote service discovery) but has not yet been pulled into the
+//     local instance map by the background sync loop.
+//   - Gap 2 (0–20 s): the node is in the local instance map but has not yet been
+//     promoted into o.nodes by keepInSync.
+//
+// discoveryGroup ensures that concurrent requests targeting the same missing
+// node share a single discovery attempt rather than fanning out.
+func (o *Orchestrator) getOrConnectNode(ctx context.Context, clusterID uuid.UUID, nodeID string) *nodemanager.Node {
+	ctx, span := tracer.Start(ctx, "get-or-connect-node")
+	defer span.End()
+
+	if node := o.GetNode(clusterID, nodeID); node != nil {
+		return node
+	}
+
+	logger.L().Warn(ctx, "Node not found in cache, attempting on-demand connection",
+		logger.WithNodeID(nodeID),
+		zap.String("cluster_id", clusterID.String()),
+	)
+
+	scopedKey := o.scopedNodeID(clusterID, nodeID)
+
+	o.discoveryGroup.Do(scopedKey, func() (any, error) { //nolint:errcheck
+		// Re-check inside the singleflight
+		if node := o.GetNode(clusterID, nodeID); node != nil {
+			return nil, nil
+		}
+
+		connectCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cacheSyncTime)
+		defer cancel()
+
+		if clusterID == consts.LocalClusterID {
+			o.discoverNomadNodes(connectCtx)
+		} else {
+			o.discoverClusterNode(connectCtx, clusterID)
+		}
+
+		return nil, nil
+	})
+
+	return o.GetNode(clusterID, nodeID)
+}
+
+// discoverNomadNodes lists all ready Nomad nodes and connects any that are not yet in the pool.
+// Once a new node is connected its orchestrator ID becomes the map key, making subsequent GetNode calls succeed.
+func (o *Orchestrator) discoverNomadNodes(ctx context.Context) {
+	ctx, span := tracer.Start(ctx, "discover-nomad-nodes")
+	defer span.End()
+
+	nomadNodes, err := o.listNomadNodes(ctx)
+	if err != nil {
+		logger.L().Error(ctx, "Error listing Nomad nodes during on-demand discovery", zap.Error(err))
+
+		return
+	}
+
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
+	for _, n := range nomadNodes {
+		if o.GetNodeByNomadShortID(n.NomadNodeShortID) == nil {
+			wg.Go(func() {
+				if err := o.connectToNode(ctx, n); err != nil {
+					logger.L().Error(ctx, "Error connecting to Nomad node on demand",
+						zap.Error(err), zap.String("nomad_short_id", n.NomadNodeShortID))
+				}
+			})
+		}
+	}
+}
+
+// discoverClusterNode forces a fresh service discovery query so that nodes which joined after the
+// last periodic sync are pulled into cluster.instances, then opportunistically connects all
+// unknown nodes into o.nodes (not just the target), avoiding repeated on-demand discoveries.
+func (o *Orchestrator) discoverClusterNode(ctx context.Context, clusterID uuid.UUID) {
+	ctx, span := tracer.Start(ctx, "discover-cluster-node")
+	defer span.End()
+
+	cluster, found := o.clusters.GetClusterById(clusterID)
+	if !found {
+		logger.L().Error(ctx, "Cluster not found during on-demand node discovery", logger.WithClusterID(clusterID))
+
+		return
+	}
+
+	if err := cluster.SyncInstances(ctx); err != nil {
+		logger.L().Error(ctx, "Error syncing cluster instances during on-demand node discovery", zap.Error(err), logger.WithClusterID(clusterID))
+
+		return
+	}
+
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
+	for _, instance := range cluster.GetOrchestrators() {
+		wg.Go(func() {
+			o.connectToClusterNode(ctx, cluster, instance)
+		})
+	}
+}
+
 func (o *Orchestrator) GetClusterNodes(clusterID uuid.UUID) []*nodemanager.Node {
 	clusterNodes := make([]*nodemanager.Node, 0)
 	for _, n := range o.nodes.Items() {
diff --git a/packages/api/internal/orchestrator/client_test.go b/packages/api/internal/orchestrator/client_test.go
new file mode 100644
index 0000000000..d0ae9e4744
--- /dev/null
+++ b/packages/api/internal/orchestrator/client_test.go
@@ -0,0 +1,337 @@
+package orchestrator
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	nomadapi "github.com/hashicorp/nomad/api"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/protobuf/types/known/emptypb"
+
+	"github.com/e2b-dev/infra/packages/api/internal/api"
+	"github.com/e2b-dev/infra/packages/api/internal/orchestrator/nodemanager"
+	"github.com/e2b-dev/infra/packages/shared/pkg/consts"
+	infogrpc "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator-info"
+	"github.com/e2b-dev/infra/packages/shared/pkg/logger"
+	"github.com/e2b-dev/infra/packages/shared/pkg/smap"
+	"github.com/e2b-dev/infra/packages/shared/pkg/telemetry"
+)
+
+// newTestOrchestrator creates a minimal Orchestrator with only the fields
+// needed for node lookup / discovery tests. Production fields (sandboxStore,
+// analytics, redis, etc.) are left nil because the code paths under test never
+// touch them.
+func newTestOrchestrator(t *testing.T, nomad *nomadapi.Client) *Orchestrator {
+	t.Helper()
+
+	ctx := t.Context()
+	logger.ReplaceGlobals(ctx, logger.NewNopLogger())
+
+	return &Orchestrator{
+		nodes:       smap.New[*nodemanager.Node](),
+		nomadClient: nomad,
+		tel:         telemetry.NewNoopClient(),
+	}
+}
+
+func newNomadMock(t *testing.T, handler http.HandlerFunc) *nomadapi.Client {
+	t.Helper()
+
+	srv := httptest.NewServer(handler)
+	t.Cleanup(srv.Close)
+
+	client, err := nomadapi.NewClient(&nomadapi.Config{Address: srv.URL})
+	require.NoError(t, err)
+
+	return client
+}
+
+// fakeInfoServer implements the minimum InfoServiceServer surface needed by
+// nodemanager.New: it responds to ServiceInfo with a canned response carrying
+// the given nodeID and Healthy status.
+type fakeInfoServer struct {
+	infogrpc.UnimplementedInfoServiceServer
+
+	nodeID string
+}
+
+func (s *fakeInfoServer) ServiceInfo(context.Context, *emptypb.Empty) (*infogrpc.ServiceInfoResponse, error) {
+	return &infogrpc.ServiceInfoResponse{
+		NodeId:         s.nodeID,
+		ServiceId:      "test-service-instance",
+		ServiceStatus:  infogrpc.ServiceInfoStatus_Healthy,
+		MetricCpuCount: 4,
+	}, nil
+}
+
+// startFakeOrchestratorGRPC starts a gRPC server that responds to ServiceInfo
+// requests. When addr is empty it listens on an ephemeral port; otherwise it
+// binds to the given address (e.g. "127.0.0.1:5008"). Returns the listener
+// address.
+func startFakeOrchestratorGRPC(t *testing.T, nodeID string, addr string) string {
+	t.Helper()
+
+	if addr == "" {
+		addr = "127.0.0.1:0"
+	}
+
+	lis, err := (&net.ListenConfig{}).Listen(t.Context(), "tcp", addr)
+	require.NoError(t, err)
+
+	srv := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
+	infogrpc.RegisterInfoServiceServer(srv, &fakeInfoServer{nodeID: nodeID})
+
+	go srv.Serve(lis)
+	t.Cleanup(srv.GracefulStop)
+
+	return lis.Addr().String()
+}
+
+// TestGetOrConnectNode_CacheHit verifies that when a node is already in the
+// cache, getOrConnectNode returns it immediately without triggering any
+// discovery. This is the fast path exercised on every sandbox operation.
+func TestGetOrConnectNode_CacheHit(t *testing.T) {
+	t.Parallel()
+
+	o := newTestOrchestrator(t, nil)
+
+	clusterID := uuid.New()
+	testNode := nodemanager.NewTestNode("node-1", api.NodeStatusReady, 3, 4)
+	testNode.ClusterID = clusterID
+	o.nodes.Insert(o.scopedNodeID(clusterID, "node-1"), testNode)
+
+	got := o.getOrConnectNode(t.Context(), clusterID, "node-1")
+	require.NotNil(t, got)
+	assert.Equal(t, "node-1", got.ID)
+}
+
+func TestGetOrConnectNode_CacheHit_LocalCluster(t *testing.T) {
+	t.Parallel()
+
+	o := newTestOrchestrator(t, nil)
+
+	testNode := nodemanager.NewTestNode("local-node", api.NodeStatusReady, 2, 4)
+	testNode.ClusterID = consts.LocalClusterID
+	o.nodes.Insert("local-node", testNode)
+
+	got := o.getOrConnectNode(t.Context(), consts.LocalClusterID, "local-node")
+	require.NotNil(t, got)
+	assert.Equal(t, "local-node", got.ID)
+}
+
+// TestGetOrConnectNode_CacheMiss_TriggersNomadDiscovery verifies that when a Nomad node is NOT
+// in the cache, getOrConnectNode triggers on-demand Nomad service discovery rather than immediately returning nil.
+//
+// This handles scenario when a new orchestrator node has joined the cluster and may not be in cache yet
+func TestGetOrConnectNode_CacheMiss_TriggersNomadDiscovery(t *testing.T) {
+	t.Parallel()
+
+	var discoveryAttempts atomic.Int32
+
+	nomadClient := newNomadMock(t, func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/v1/nodes" {
+			discoveryAttempts.Add(1)
+
+			// Return a node stub. connectToNode will fail at the gRPC level
+			// (nodemanager.New dials the fake address), but the important thing
+			// is that discovery WAS attempted.
+			resp := []*nomadapi.NodeListStub{
+				{
+					ID:       "abcdef1234567890abcdef1234567890abcdef12",
+					Address:  "127.0.0.1",
+					Status:   "ready",
+					NodePool: "default",
+				},
+			}
+			w.Header().Set("Content-Type", "application/json")
+			json.NewEncoder(w).Encode(resp)
+
+			return
+		}
+
+		http.NotFound(w, r)
+	})
+
+	o := newTestOrchestrator(t, nomadClient)
+
+	ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
+	defer cancel()
+
+	// Request a node that isn't in cache — should trigger discovery.
+	o.getOrConnectNode(ctx, consts.LocalClusterID, "nonexistent-node")
+
+	// The node won't be found because connectToNode fails at the gRPC level,
+	// but discovery MUST have been attempted.
+	assert.Positive(t, discoveryAttempts.Load(), "expected on-demand Nomad discovery to be triggered")
+}
+
+// TestGetOrConnectNode_ConcurrentCacheMiss_SharesDiscovery verifies that
+// multiple concurrent getOrConnectNode calls for the same missing node share
+// a single discovery attempt
+func TestGetOrConnectNode_ConcurrentCacheMiss_SharesDiscovery(t *testing.T) {
+	t.Parallel()
+
+	var discoveryAttempts atomic.Int32
+
+	nomadClient := newNomadMock(t, func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/v1/nodes" {
+			discoveryAttempts.Add(1)
+			// Slow response to ensure concurrent callers overlap.
+			time.Sleep(100 * time.Millisecond)
+
+			w.Header().Set("Content-Type", "application/json")
+			json.NewEncoder(w).Encode([]*nomadapi.NodeListStub{})
+
+			return
+		}
+
+		http.NotFound(w, r)
+	})
+
+	o := newTestOrchestrator(t, nomadClient)
+
+	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+	defer cancel()
+
+	// Fire 10 concurrent lookups for the same missing node.
+	const concurrency = 10
+	done := make(chan struct{}, concurrency)
+	for range concurrency {
+		go func() {
+			defer func() { done <- struct{}{} }()
+			o.getOrConnectNode(ctx, consts.LocalClusterID, "missing-node")
+		}()
+	}
+
+	for range concurrency {
+		<-done
+	}
+
+	// should collapse all 10 calls into ≤2
+	// one flight, possibly one more if the first completed before a late caller arrived
+	assert.LessOrEqual(t, discoveryAttempts.Load(), int32(2),
+		"singleflight should deduplicate concurrent discovery attempts")
+}
+
+// TestConnectToNode_SingleflightDedup verifies that concurrent connectToNode
+// calls for the same NomadNodeShortID share a single connection attempt
+func TestConnectToNode_SingleflightDedup(t *testing.T) {
+	t.Parallel()
+
+	o := newTestOrchestrator(t, nil)
+
+	// grpc.NewClient is lazy — it returns immediately — and nodemanager.New
+	// then fails at the ServiceInfo RPC call
+	discovery := nodemanager.NomadServiceDiscovery{
+		NomadNodeShortID:    "abcdef12",
+		OrchestratorAddress: "127.0.0.1:1",
+		IPAddress:           "127.0.0.1",
+	}
+
+	const concurrency = 10
+	errs := make(chan error, concurrency)
+
+	for range concurrency {
+		go func() {
+			ctx, cancel := context.WithTimeout(t.Context(), 2*time.Second)
+			defer cancel()
+			errs <- o.connectToNode(ctx, discovery)
+		}()
+	}
+
+	for range concurrency {
+		<-errs
+	}
+
+	// After all calls complete, verify the node map has at most 1 entry
+	// (or 0 if all failed)
+	assert.LessOrEqual(t, o.nodes.Count(), 1, "singleflight should prevent duplicate registrations")
+}
+
+// TestGetOrConnectNode_CacheMiss_DiscoversAndConnects is the end-to-end
+// test for a race condition. It simulates following scenario:
+//
+//  1. A new orchestrator node is running and reachable (fake gRPC server).
+//  2. Nomad service discovery knows about it (mock HTTP API).
+//  3. This API instance has NOT yet synced (node is absent from o.nodes).
+//  4. A handler calls getOrConnectNode for a sandbox on that node.
+//
+// The fake gRPC server listens on consts.OrchestratorAPIPort so that
+// listNomadNodes builds the correct address (ip:OrchestratorAPIPort).
+func TestGetOrConnectNode_CacheMiss_DiscoversAndConnects(t *testing.T) {
+	t.Parallel()
+
+	orchestratorNodeID := "orch-node-42"
+	nomadFullID := "aabbccdd11223344aabbccdd11223344aabbccdd"
+
+	// 1. Start a fake gRPC server on consts.OrchestratorAPIPort so that the
+	// address built by listNomadNodes matches our listener.
+	listenAddr := fmt.Sprintf("127.0.0.1:%d", consts.OrchestratorAPIPort)
+	startFakeOrchestratorGRPC(t, orchestratorNodeID, listenAddr)
+
+	// 2. Mock Nomad HTTP API returning a single ready node at 127.0.0.1.
+	nomadClient := newNomadMock(t, func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path == "/v1/nodes" {
+			resp := []*nomadapi.NodeListStub{
+				{
+					ID:       nomadFullID,
+					Address:  "127.0.0.1",
+					Status:   "ready",
+					NodePool: "default",
+				},
+			}
+			w.Header().Set("Content-Type", "application/json")
+			json.NewEncoder(w).Encode(resp)
+
+			return
+		}
+
+		http.NotFound(w, r)
+	})
+
+	o := newTestOrchestrator(t, nomadClient)
+
+	// 3. Verify the node is NOT in cache.
+	assert.Nil(t, o.GetNode(consts.LocalClusterID, orchestratorNodeID))
+
+	// 4. getOrConnectNode triggers the full discovery path:
+	// cache miss → discoverNomadNodes → listNomadNodes → connectToNode → gRPC ServiceInfo → registerNode
+	node := o.getOrConnectNode(t.Context(), consts.LocalClusterID, orchestratorNodeID)
+	require.NotNil(t, node, "getOrConnectNode must discover and connect the node via Nomad")
+	assert.Equal(t, orchestratorNodeID, node.ID)
+	assert.Equal(t, nomadFullID[:consts.NodeIDLength], node.NomadNodeShortID)
+}
+
+// TestRegisterNode_NoDuplicates verifies that registerNode is idempotent
+// when the same scoped key is used — the last write wins, and the map never
+// grows beyond the number of unique keys.
+func TestRegisterNode_NoDuplicates(t *testing.T) {
+	t.Parallel()
+
+	o := newTestOrchestrator(t, nil)
+
+	clusterID := uuid.New()
+	wg := sync.WaitGroup{}
+	for i := range 50 {
+		wg.Go(func() {
+			node := nodemanager.NewTestNode(fmt.Sprintf("node-%d", i%5), api.NodeStatusReady, 2, 4)
+			node.ClusterID = clusterID
+			o.registerNode(node)
+		})
+	}
+
+	wg.Wait()
+	assert.Equal(t, 5, o.nodes.Count())
+}
diff --git a/packages/api/internal/orchestrator/delete_instance.go b/packages/api/internal/orchestrator/delete_instance.go
index b9bf391b5c..611e9594fe 100644
--- a/packages/api/internal/orchestrator/delete_instance.go
+++ b/packages/api/internal/orchestrator/delete_instance.go
@@ -100,7 +100,7 @@ func (o *Orchestrator) removeSandboxFromNode(ctx context.Context, sbx sandbox.Sa
 	ctx, span := tracer.Start(ctx, "remove-sandbox-from-node")
 	defer span.End()
 
-	node := o.GetNode(sbx.ClusterID, sbx.NodeID)
+	node := o.getOrConnectNode(ctx, sbx.ClusterID, sbx.NodeID)
 	if node == nil {
 		logger.L().Error(ctx, "failed to get node", logger.WithNodeID(sbx.NodeID))
 
diff --git a/packages/api/internal/orchestrator/orchestrator.go b/packages/api/internal/orchestrator/orchestrator.go
index b195c4b153..3bdc28c88b 100644
--- a/packages/api/internal/orchestrator/orchestrator.go
+++ b/packages/api/internal/orchestrator/orchestrator.go
@@ -11,6 +11,7 @@ import (
 	"github.com/redis/go-redis/v9"
 	"go.opentelemetry.io/otel/metric"
 	"go.uber.org/zap"
+	"golang.org/x/sync/singleflight"
 
 	analyticscollector "github.com/e2b-dev/infra/packages/api/internal/analytics_collector"
 	"github.com/e2b-dev/infra/packages/api/internal/cfg"
@@ -66,6 +67,21 @@ type Orchestrator struct {
 	snapshotCache SnapshotCacheInvalidator
 
 	snapshotUpsertSem *utils.AdjustableSemaphore
+
+	// connectGroup deduplicates concurrent dial+register attempts for the same
+	// physical node. It is keyed by NomadNodeShortID (Nomad-managed nodes) or
+	// scopedNodeID(clusterID, instanceNodeID) (cluster nodes) and is held inside
+	// connectToNode / connectToClusterNode, so it guards every connection path
+	// regardless of what triggered the attempt.
+	connectGroup singleflight.Group
+
+	// discoveryGroup deduplicates concurrent on-demand discovery attempts in
+	// getOrConnectNode that target the same missing orchestrator node. It is
+	// intentionally separate from connectGroup to avoid a deadlock: for cluster
+	// nodes the outer discoveryGroup key and the inner connectGroup key are the
+	// same string, and nesting Do calls for the same key on the same Group would
+	// block forever.
+	discoveryGroup singleflight.Group
 }
 
 func New(
diff --git a/packages/api/internal/orchestrator/snapshot_template.go b/packages/api/internal/orchestrator/snapshot_template.go
index 600438d2b3..697784350d 100644
--- a/packages/api/internal/orchestrator/snapshot_template.go
+++ b/packages/api/internal/orchestrator/snapshot_template.go
@@ -63,7 +63,7 @@ func (o *Orchestrator) CreateSnapshotTemplate(ctx context.Context, teamID uuid.U
 	}
 	defer finish(nil)
 
-	node := o.GetNode(sbx.ClusterID, sbx.NodeID)
+	node := o.getOrConnectNode(ctx, sbx.ClusterID, sbx.NodeID)
 	if node == nil {
 		return SnapshotTemplateResult{}, fmt.Errorf("node '%s' not found", sbx.NodeID)
 	}
diff --git a/packages/api/internal/orchestrator/update_instance.go b/packages/api/internal/orchestrator/update_instance.go
index f751432d15..a6ea40dc75 100644
--- a/packages/api/internal/orchestrator/update_instance.go
+++ b/packages/api/internal/orchestrator/update_instance.go
@@ -31,7 +31,12 @@ func (o *Orchestrator) UpdateSandbox(
 	)
 	defer span.End()
 
-	client, ctx := o.GetNode(clusterID, nodeID).GetClient(ctx)
+	node := o.getOrConnectNode(ctx, clusterID, nodeID)
+	if node == nil {
+		return fmt.Errorf("node '%s' not found", nodeID)
+	}
+
+	client, ctx := node.GetClient(ctx)
 	_, err := client.Sandbox.Update(
 		ctx, &orchestrator.SandboxUpdateRequest{
 			SandboxId: sandboxID,
diff --git a/packages/api/internal/orchestrator/update_network.go b/packages/api/internal/orchestrator/update_network.go
index 1d2ebab643..371a99182f 100644
--- a/packages/api/internal/orchestrator/update_network.go
+++ b/packages/api/internal/orchestrator/update_network.go
@@ -76,7 +76,7 @@ func (o *Orchestrator) updateSandboxNetworkOnNode(
 	)
 	defer span.End()
 
-	node := o.GetNode(sbx.ClusterID, sbx.NodeID)
+	node := o.getOrConnectNode(ctx, sbx.ClusterID, sbx.NodeID)
 	if node == nil {
 		return &api.APIError{
 			Code: http.StatusInternalServerError,
diff --git a/packages/shared/pkg/synchronization/synchronization.go b/packages/shared/pkg/synchronization/synchronization.go
index c4a2e359e0..dee2b1ca57 100644
--- a/packages/shared/pkg/synchronization/synchronization.go
+++ b/packages/shared/pkg/synchronization/synchronization.go
@@ -8,6 +8,7 @@ import (
 
 	"go.opentelemetry.io/otel"
 	"go.uber.org/zap"
+	"golang.org/x/sync/semaphore"
 
 	"github.com/e2b-dev/infra/packages/shared/pkg/logger"
 )
@@ -35,6 +36,9 @@ type Synchronize[SourceItem any, PoolItem any] struct {
 
 	cancel     chan struct{} // channel for cancellation of synchronization
 	cancelOnce sync.Once
+
+	// syncSem prevents concurrent PoolInsert calls
+	syncSem *semaphore.Weighted
 }
 
 func NewSynchronize[SourceItem any, PoolItem any](spanPrefix string, logsPrefix string, store Store[SourceItem, PoolItem]) *Synchronize[SourceItem, PoolItem] {
@@ -43,6 +47,7 @@ func NewSynchronize[SourceItem any, PoolItem any](spanPrefix string, logsPrefix
 		logsPrefix: logsPrefix,
 		store:      store,
 		cancel:     make(chan struct{}),
+		syncSem:    semaphore.NewWeighted(1),
 	}
 
 	return s
@@ -51,7 +56,7 @@ func NewSynchronize[SourceItem any, PoolItem any](spanPrefix string, logsPrefix
 func (s *Synchronize[SourceItem, PoolItem]) Start(ctx context.Context, syncInterval time.Duration, syncRoundTimeout time.Duration, runInitialSync bool) {
 	if runInitialSync {
 		initialSyncTimeout, initialSyncCancel := context.WithTimeout(context.WithoutCancel(ctx), syncRoundTimeout)
-		err := s.sync(initialSyncTimeout)
+		err := s.Sync(initialSyncTimeout)
 		initialSyncCancel()
 		if err != nil {
 			logger.L().Error(ctx, s.getLog("Initial sync failed"), zap.Error(err))
@@ -69,7 +74,7 @@ func (s *Synchronize[SourceItem, PoolItem]) Start(ctx context.Context, syncInter
 				return
 			case <-timer.C:
 				syncTimeout, syncCancel := context.WithTimeout(context.WithoutCancel(ctx), syncRoundTimeout)
-				err := s.sync(syncTimeout)
+				err := s.Sync(syncTimeout)
 				syncCancel()
 				if err != nil {
 					logger.L().Error(ctx, s.getLog("Failed to synchronize"), zap.Error(err))
@@ -84,10 +89,16 @@ func (s *Synchronize[SourceItem, PoolItem]) Close() {
 	)
 }
 
-func (s *Synchronize[SourceItem, PoolItem]) sync(ctx context.Context) error {
+// Sync performs periodic sync or it can be done as an on-demand synchronization round against the source.
+func (s *Synchronize[SourceItem, PoolItem]) Sync(ctx context.Context) error {
 	ctx, span := tracer.Start(ctx, s.getSpanName("sync-items"))
 	defer span.End()
 
+	if err := s.syncSem.Acquire(ctx, 1); err != nil {
+		return fmt.Errorf("failed to acquire sync lock: %w", err)
+	}
+	defer s.syncSem.Release(1)
+
 	sourceItems, err := s.store.SourceList(ctx)
 	if err != nil {
 		return err
diff --git a/packages/shared/pkg/synchronization/synchronization_test.go b/packages/shared/pkg/synchronization/synchronization_test.go
index 98cf060347..853b9b40b8 100644
--- a/packages/shared/pkg/synchronization/synchronization_test.go
+++ b/packages/shared/pkg/synchronization/synchronization_test.go
@@ -5,6 +5,11 @@ import (
 	"slices"
 	"sync"
 	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"golang.org/x/sync/semaphore"
 
 	"github.com/e2b-dev/infra/packages/shared/pkg/logger"
 )
@@ -87,9 +92,64 @@ func newSynchronizer(ctx context.Context, store Store[string, string]) *Synchron
 		store:            store,
 		tracerSpanPrefix: "test synchronization",
 		logsPrefix:       "test synchronization",
+		syncSem:          semaphore.NewWeighted(1),
+	}
+}
+
+// slowTestStore embeds testStore but overrides SourceList to block until
+// the unblock channel is closed. This simulates a long-running sync holding
+// the semaphore.
+type slowTestStore struct {
+	*testStore
+
+	unblock chan struct{}
+}
+
+func (s *slowTestStore) SourceList(ctx context.Context) ([]string, error) {
+	select {
+	case <-s.unblock:
+		return s.testStore.SourceList(ctx)
+	case <-ctx.Done():
+		return nil, ctx.Err()
 	}
 }
 
+// TestSynchronize_SyncRespectsContextCancellation verifies that a second
+// Sync call returns promptly when its context expires while the first Sync
+// holds the semaphore. With the old sync.Mutex this would block indefinitely;
+// the semaphore.Weighted implementation respects context cancellation.
+func TestSynchronize_SyncRespectsContextCancellation(t *testing.T) {
+	t.Parallel()
+	ctx := t.Context()
+
+	slow := &slowTestStore{
+		testStore: newTestStore([]string{"a"}, nil),
+		unblock:   make(chan struct{}),
+	}
+	syncer := newSynchronizer(ctx, slow)
+
+	// First sync: acquires the semaphore and blocks inside SourceList.
+	firstDone := make(chan error, 1)
+	go func() {
+		firstDone <- syncer.Sync(ctx)
+	}()
+
+	// Give the first goroutine time to acquire the semaphore.
+	time.Sleep(20 * time.Millisecond)
+
+	// Second sync: should fail fast when its context deadline expires.
+	shortCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
+	defer cancel()
+
+	err := syncer.Sync(shortCtx)
+	require.Error(t, err)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+
+	// Unblock the first sync and verify it completes successfully.
+	close(slow.unblock)
+	require.NoError(t, <-firstDone)
+}
+
 func TestSynchronize_InsertAndRemove(t *testing.T) {
 	t.Parallel()
 	ctx := t.Context()
@@ -98,29 +158,15 @@ func TestSynchronize_InsertAndRemove(t *testing.T) {
 
 	s := newTestStore([]string{"a", "b"}, nil)
 	syncer := newSynchronizer(ctx, s)
 
-	if err := syncer.sync(ctx); err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
-
-	if want, got := 2, s.inserts; want != got {
-		t.Fatalf("insert count mismatch: want %d got %d", want, got)
-	}
-
-	if len(s.pool) != 2 {
-		t.Fatalf("pool size want 2 got %d", len(s.pool))
-	}
+	require.NoError(t, syncer.Sync(ctx))
+	assert.Equal(t, 2, s.inserts)
+	assert.Len(t, s.pool, 2)
 
 	// Now remove "b" from the source – should trigger exactly one removal.
 	s.source = []string{"a"}
-	if err := syncer.sync(ctx); err != nil {
-		t.Fatalf("unexpected error: %v", err)
-	}
+	require.NoError(t, syncer.Sync(ctx))
 
-	if want, got := 1, s.removes; want != got {
-		t.Fatalf("remove count mismatch: want %d got %d", want, got)
-	}
-
-	if len(s.pool) != 1 || !s.PoolExists(ctx, "a") {
-		t.Fatalf("pool contents after removal are incorrect: %#v", s.pool)
-	}
+	assert.Equal(t, 1, s.removes)
+	assert.Len(t, s.pool, 1)
+	assert.True(t, s.PoolExists(ctx, "a"))
 }