From 8cee14faa99016722001a0ceea21f70af626ee24 Mon Sep 17 00:00:00 2001 From: Tanmay Sardesai Date: Wed, 11 Feb 2026 14:57:44 -0800 Subject: [PATCH] fix(scaletozero): skip scale-to-zero toggle for loopback connections Loopback (localhost) requests such as internal health checks were toggling the scale-to-zero control file, keeping VMs alive indefinitely. The middleware now detects loopback addresses and passes them through without touching scale-to-zero state. Also adds info-level logging around the control-file writes and covers the new behaviour with unit tests. Co-authored-by: Cursor --- server/cmd/api/api/process.go | 1 - server/lib/scaletozero/middleware.go | 22 +++++ server/lib/scaletozero/middleware_test.go | 114 ++++++++++++++++++++++ server/lib/scaletozero/scaletozero.go | 2 + 4 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 server/lib/scaletozero/middleware_test.go diff --git a/server/cmd/api/api/process.go b/server/cmd/api/api/process.go index 3dd6aeb4..8367637f 100644 --- a/server/cmd/api/api/process.go +++ b/server/cmd/api/api/process.go @@ -624,7 +624,6 @@ func (s *ApiService) ProcessResize(ctx context.Context, request oapi.ProcessResi return oapi.ProcessResize200JSONResponse(oapi.OkResponse{Ok: true}), nil } - // writeJSON writes a JSON response with the given status code. // Unlike http.Error, this sets the correct Content-Type for JSON. func writeJSON(w http.ResponseWriter, status int, body string) { diff --git a/server/lib/scaletozero/middleware.go b/server/lib/scaletozero/middleware.go index f67c06e6..181e43d0 100644 --- a/server/lib/scaletozero/middleware.go +++ b/server/lib/scaletozero/middleware.go @@ -2,6 +2,7 @@ package scaletozero import ( "context" + "net" "net/http" "github.com/onkernel/kernel-images/server/lib/logger" @@ -9,9 +10,16 @@ import ( // Middleware returns a standard net/http middleware that disables scale-to-zero // at the start of each request and re-enables it after the handler completes. +// Connections from loopback addresses are ignored and do not affect the +// scale-to-zero state. func Middleware(ctrl Controller) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isLoopbackAddr(r.RemoteAddr) { + next.ServeHTTP(w, r) + return + } + if err := ctrl.Disable(r.Context()); err != nil { logger.FromContext(r.Context()).Error("failed to disable scale-to-zero", "error", err) http.Error(w, "failed to disable scale-to-zero", http.StatusInternalServerError) @@ -23,3 +31,17 @@ func Middleware(ctrl Controller) func(http.Handler) http.Handler { }) } } + +// isLoopbackAddr reports whether addr is a loopback address. +// addr may be an "ip:port" pair or a bare IP. +func isLoopbackAddr(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + return ip.IsLoopback() +} diff --git a/server/lib/scaletozero/middleware_test.go b/server/lib/scaletozero/middleware_test.go new file mode 100644 index 00000000..c48b6122 --- /dev/null +++ b/server/lib/scaletozero/middleware_test.go @@ -0,0 +1,114 @@ +package scaletozero + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMiddlewareDisablesAndEnablesForExternalAddr(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.50:12345" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 1, mock.disableCalls) + assert.Equal(t, 1, mock.enableCalls) +} + +func TestMiddlewareSkipsLoopbackAddrs(t *testing.T) { + t.Parallel() + + loopbackAddrs := []struct { + name string + addr string + }{ + {"loopback-v4", "127.0.0.1:8080"}, + {"loopback-v6", "[::1]:8080"}, + } + + for _, tc := range loopbackAddrs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{} + var called bool + handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = tc.addr + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.True(t, called, "handler should still be called") + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 0, mock.disableCalls, "should not disable for loopback addr") + assert.Equal(t, 0, mock.enableCalls, "should not enable for loopback addr") + }) + } +} + +func TestMiddlewareDisableError(t *testing.T) { + t.Parallel() + mock := &mockScaleToZeroer{disableErr: assert.AnError} + var called bool + handler := Middleware(mock)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.50:12345" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.False(t, called, "handler should not be called on disable error") + assert.Equal(t, http.StatusInternalServerError, rec.Code) + assert.Equal(t, 0, mock.enableCalls) +} + +func TestIsLoopbackAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + addr string + loopback bool + }{ + // Loopback + {"127.0.0.1:80", true}, + {"[::1]:80", true}, + {"127.0.0.1", true}, + {"::1", true}, + // Non-loopback + {"10.0.0.1:80", false}, + {"172.16.0.1:80", false}, + {"192.168.1.1:80", false}, + {"203.0.113.50:80", false}, + {"8.8.8.8:53", false}, + {"[2001:db8::1]:80", false}, + // Unparseable + {"not-an-ip:80", false}, + {"", false}, + } + + for _, tc := range tests { + t.Run(tc.addr, func(t *testing.T) { + t.Parallel() + require.Equal(t, tc.loopback, isLoopbackAddr(tc.addr)) + }) + } +} diff --git a/server/lib/scaletozero/scaletozero.go b/server/lib/scaletozero/scaletozero.go index 96b67281..c37fcba3 100644 --- a/server/lib/scaletozero/scaletozero.go +++ b/server/lib/scaletozero/scaletozero.go @@ -38,6 +38,7 @@ func (c *unikraftCloudController) Enable(ctx context.Context) error { func (c *unikraftCloudController) write(ctx context.Context, char string) error { if _, err := os.Stat(c.path); err != nil { if os.IsNotExist(err) { + logger.FromContext(ctx).Info("scale-to-zero control file not found, skipping write", "path", c.path, "value", char) return nil } logger.FromContext(ctx).Error("failed to stat scale-to-zero control file", "path", c.path, "err", err) @@ -54,6 +55,7 @@ func (c *unikraftCloudController) write(ctx context.Context, char string) error logger.FromContext(ctx).Error("failed to write scale-to-zero control file", "path", c.path, "err", err) return err } + logger.FromContext(ctx).Info("scale-to-zero control file written", "path", c.path, "value", char) return nil }