Skip to content
2 changes: 1 addition & 1 deletion .github/workflows/kind-e2e.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ jobs:
run: |
gotestsum --format testname -- \
-race -count=1 -parallel=1 -tags=e2e \
-timeout=30m \
-timeout=35m \
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the extra test we need to bump this since we run things with -parallel=1 due to github actions default runners being crappy.

Copy link
Copy Markdown
Member Author

@dprotaso dprotaso Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I adjusted the test to not use time.Sleep to speed it up

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not that relevant for this PR but timing tests could also be sped up using synctest after requiring go 1.25, see https://go.dev/blog/testing-time

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for e2e. The suite would run faster if we ran the tests in parallel, but with kind on GitHub Actions runners the machines are so underprovisioned that the tests become flaky.

${{ matrix.test-path }} \
-skip-cleanup-on-fail \
-disable-logstream \
Expand Down
66 changes: 66 additions & 0 deletions pkg/http/handler/hijack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
Copyright 2026 The Knative Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package handler

import (
"cmp"
"context"
"net/http"
"sync/atomic"
"time"
)

// HijackTracker is used to track Websocket Connections
// Go net/http by default will not manage connections that
// are hijacked. Thus http.Server::Shutdown will not wait
// for those connections to finish.
//
// What this handler does is track inflight requests
// using a counter and drain will loop and poll until
// all the requests are finished.
type HijackTracker struct {
Handler http.Handler
PollInterval time.Duration

inflight atomic.Int64
}

// Drain should be called after http.Server:Shutdown returns
func (s *HijackTracker) Drain(ctx context.Context) error {
pollInterval := cmp.Or(s.PollInterval, time.Second)

ticker := time.NewTicker(pollInterval)
defer ticker.Stop()

for {
if s.inflight.Load() == 0 {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}

func (s *HijackTracker) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.inflight.Add(1)
defer s.inflight.Add(-1)
Comment thread
dprotaso marked this conversation as resolved.

s.Handler.ServeHTTP(w, r)
}
139 changes: 139 additions & 0 deletions pkg/http/handler/hijack_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
Copyright 2026 The Knative Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package handler

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
)

// TestHijackTrackerNoHijack verifies that Drain returns without error once
// the only request served through the tracker has already completed.
func TestHijackTrackerNoHijack(t *testing.T) {
	tracker := &HijackTracker{
		Handler: http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
			rw.WriteHeader(http.StatusOK)
		}),
	}

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "http://somehost.com", nil)
	tracker.ServeHTTP(rec, req)

	if drainErr := tracker.Drain(context.Background()); drainErr != nil {
		t.Fatal("unexpected error while draining", drainErr)
	}
}

// TestHijackTrackerConnectionHijacked verifies that Drain keeps blocking
// while a handler is still running and returns nil once that handler
// finishes.
func TestHijackTrackerConnectionHijacked(t *testing.T) {
	w := httptest.NewRecorder()
	r := httptest.NewRequest(http.MethodGet, "http://somehost.com", nil)

	inHandler := make(chan struct{})
	handlerWait := make(chan struct{})
	drainResult := make(chan error, 1)

	// releaseHandler unblocks the handler at most once. It is deferred so
	// that an early t.Fatal does not leave the handler and Drain goroutines
	// blocked after the test has ended. It is only ever called from the test
	// goroutine, so the plain bool needs no synchronisation.
	released := false
	releaseHandler := func() {
		if !released {
			released = true
			close(handlerWait)
		}
	}
	defer releaseHandler()

	h := &HijackTracker{
		PollInterval: 10 * time.Millisecond,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			close(inHandler)
			<-handlerWait
		}),
	}

	go func() {
		h.ServeHTTP(w, r)
	}()

	// Wait until the request is known to be in flight before draining.
	select {
	case <-inHandler:
	case <-time.After(250 * time.Millisecond):
		t.Fatal("control flow never reached the http handler")
	}

	go func() {
		drainResult <- h.Drain(context.Background())
	}()

	// Drain must stay blocked while the handler has not returned.
	select {
	case <-time.After(250 * time.Millisecond):
	case <-drainResult:
		t.Fatal("drain returned before handler was finished")
	}

	releaseHandler()

	select {
	case <-time.After(1 * time.Second):
		t.Fatal("Drain was not unblocked when the handler returned")
	case err := <-drainResult:
		if err != nil {
			t.Fatal("unexpected error draining", err)
		}
	}
}

// TestHijackTrackerConnectionHijackedTimeout verifies that Drain gives up
// with context.DeadlineExceeded when the handler outlives the drain context.
func TestHijackTrackerConnectionHijackedTimeout(t *testing.T) {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "http://somehost.com", nil)

	var (
		handlerEntered = make(chan struct{})
		handlerRelease = make(chan struct{})
		drainReady     = make(chan struct{})
		drainErr       = make(chan error, 1)
	)

	tracker := &HijackTracker{
		PollInterval: 10 * time.Millisecond,
		Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
			close(handlerEntered)
			<-handlerRelease
		}),
	}

	go tracker.ServeHTTP(rec, req)

	go func() {
		// Only start draining once the request is known to be in flight.
		<-handlerEntered
		close(drainReady)

		deadlineCtx, cancel := context.WithTimeout(context.Background(), 15*time.Millisecond)
		defer cancel()
		drainErr <- tracker.Drain(deadlineCtx)
	}()

	<-drainReady
	// Deferred so the blocked handler goroutine is released when the test
	// function returns, cleaning up after the test.
	defer close(handlerRelease)

	select {
	case <-time.After(1 * time.Second):
		t.Fatal("Drain did not timeout")
	case got := <-drainErr:
		if !errors.Is(got, context.DeadlineExceeded) {
			t.Fatal("unexpected error draining", got)
		}
	}
}
26 changes: 18 additions & 8 deletions pkg/queue/sharedmain/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ import (
"knative.dev/serving/pkg/queue/health"
)

// drainers bundles the two drain hooks that mainHandler returns to its
// caller alongside the composed handler, for use during shutdown.
type drainers struct {
// HijackedDrainer wraps the composed request handler and tracks requests
// that are still in flight. Main calls its Drain, bounded by a timeout,
// after the HTTP servers have been shut down.
HijackedDrainer *handler.HijackTracker
// StandardDrainer wraps HijackedDrainer as its Inner handler. Main calls
// its Drain on the TERM signal, before the HTTP servers are shut down, and
// also passes it to adminHandler.
StandardDrainer *pkghandler.Drainer
}

func mainHandler(
env config,
d Defaults,
Expand All @@ -42,9 +47,9 @@ func mainHandler(
logger *zap.SugaredLogger,
mp metric.MeterProvider,
tp trace.TracerProvider,
) (http.Handler, *pkghandler.Drainer) {
) (http.Handler, drainers) {
var drainers drainers
tracer := tp.Tracer("knative.dev/serving/pkg/queue")

breaker := buildBreaker(logger, env)

timeout := time.Duration(env.RevisionTimeoutSeconds) * time.Second
Expand All @@ -63,21 +68,26 @@ func mainHandler(
composedHandler = requestAppMetricsHandler(logger, composedHandler, breaker, mp)
composedHandler = queue.ProxyHandler(tracer, breaker, stats, composedHandler)
composedHandler = queue.ForwardedShimHandler(composedHandler)
composedHandler = handler.NewTimeoutHandler(composedHandler, "request timeout", func(r *http.Request) (time.Duration, time.Duration, time.Duration) {
return timeout, responseStartTimeout, idleTimeout
}, logger)
composedHandler = handler.NewTimeoutHandler(composedHandler, "request timeout",
func(r *http.Request) (time.Duration, time.Duration, time.Duration) {
return timeout, responseStartTimeout, idleTimeout
}, logger)

composedHandler = queue.NewRouteTagHandler(composedHandler)
composedHandler = withFullDuplex(composedHandler, env.EnableHTTPFullDuplex, logger)

drainer := &pkghandler.Drainer{
drainers.HijackedDrainer = &handler.HijackTracker{Handler: composedHandler}
composedHandler = drainers.HijackedDrainer

drainers.StandardDrainer = &pkghandler.Drainer{
QuietPeriod: drainSleepDuration,
// Add Activator probe header to the drainer so it can handle probes directly from activator
HealthCheckUAPrefixes: []string{netheader.ActivatorUserAgent, netheader.AutoscalingUserAgent},
Inner: composedHandler,
HealthCheck: health.ProbeHandler(tracer, prober),
}
composedHandler = drainer

composedHandler = drainers.StandardDrainer

if env.Observability.EnableRequestLog {
// We want to capture the probes/healthchecks in the request logs.
Expand All @@ -95,7 +105,7 @@ func mainHandler(
}),
)

return composedHandler, drainer
return composedHandler, drainers
}

func adminHandler(ctx context.Context, logger *zap.SugaredLogger, drainer *pkghandler.Drainer) http.Handler {
Expand Down
24 changes: 19 additions & 5 deletions pkg/queue/sharedmain/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ const (
// from its configuration and propagate that to all loadbalancers and nodes.
drainSleepDuration = 30 * time.Second

// hijackDrainDuration is the maximum extra time we wait, after the
// servers have been shut down, for hijacked connections to drain.
hijackDrainDuration = 60 * time.Second

// certPath is the path for the server certificate mounted by queue-proxy.
certPath = queue.CertDirectory + "/" + certificates.CertName

Expand Down Expand Up @@ -242,8 +246,8 @@ func Main(opts ...Option) error {
// Enable TLS when certificate is mounted.
tlsEnabled := exists(logger, certPath) && exists(logger, keyPath)

mainHandler, drainer := mainHandler(env, d, probe, stats, logger, mp, tp)
adminHandler := adminHandler(d.Ctx, logger, drainer)
mainHandler, drainers := mainHandler(env, d, probe, stats, logger, mp, tp)
adminHandler := adminHandler(d.Ctx, logger, drainers.StandardDrainer)

// Enable TLS server when activator server certs are mounted.
// At this moment activator with TLS does not disable HTTP.
Expand Down Expand Up @@ -312,21 +316,31 @@ func Main(opts ...Option) error {
case <-d.Ctx.Done():
logger.Info("Received TERM signal, attempting to gracefully shutdown servers.")
logger.Infof("Sleeping %v to allow K8s propagation of non-ready state", drainSleepDuration)
drainer.Drain()
drainers.StandardDrainer.Drain()

ctx := context.Background()

for name, srv := range httpServers {
logger.Info("Shutting down server: ", name)
if err := srv.Shutdown(context.Background()); err != nil {
if err := srv.Shutdown(ctx); err != nil {
logger.Errorw("Failed to shutdown server", zap.String("server", name), zap.Error(err))
}
}
for name, srv := range tlsServers {
logger.Info("Shutting down server: ", name)
if err := srv.Shutdown(context.Background()); err != nil {
if err := srv.Shutdown(ctx); err != nil {
logger.Errorw("Failed to shutdown server", zap.String("server", name), zap.Error(err))
}
}

// Limit hijack draining to 60s for now
ctx, cancel := context.WithTimeout(ctx, hijackDrainDuration)
defer cancel()

if err := drainers.HijackedDrainer.Drain(ctx); err != nil {
logger.Warnw("Hijack connection drain failed", zap.Error(err))
}

logger.Info("Shutdown complete, exiting...")
}
return nil
Expand Down
Loading
Loading