Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions server/cmd/api/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ type ApiService struct {
upstreamMgr *devtoolsproxy.UpstreamManager
stz scaletozero.Controller

// CDP event pipeline and cdpMonitor.
captureSession *events.CaptureSession
cdpMonitor *cdpmonitor.Monitor
monitorMu sync.Mutex

// inputMu serializes input-related operations (mouse, keyboard, screenshot)
inputMu sync.Mutex

Expand Down Expand Up @@ -70,11 +75,6 @@ type ApiService struct {
// xvfbResizeMu serializes background Xvfb restarts to prevent races
// when multiple CDP fast-path resizes fire in quick succession.
xvfbResizeMu sync.Mutex

// CDP event pipeline and cdpMonitor.
captureSession *events.CaptureSession
cdpMonitor *cdpmonitor.Monitor
monitorMu sync.Mutex
}

var _ oapi.StrictServerInterface = (*ApiService)(nil)
Expand All @@ -101,8 +101,6 @@ func New(
return nil, fmt.Errorf("captureSession cannot be nil")
}

mon := cdpmonitor.New(upstreamMgr, captureSession.Publish, displayNum)

return &ApiService{
recordManager: recordManager,
factory: factory,
Expand All @@ -114,7 +112,7 @@ func New(
nekoAuthClient: nekoAuthClient,
policy: &policy.Policy{},
captureSession: captureSession,
cdpMonitor: mon,
cdpMonitor: cdpmonitor.New(upstreamMgr, captureSession.Publish, displayNum),
}, nil
}

Expand Down Expand Up @@ -335,8 +333,14 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ

func (s *ApiService) Shutdown(ctx context.Context) error {
s.monitorMu.Lock()
s.cdpMonitor.Stop()
_ = s.captureSession.Close()
if s.cdpMonitor != nil {
s.cdpMonitor.Stop()
}
if s.captureSession != nil {
if err := s.captureSession.Close(); err != nil {
logger.FromContext(ctx).Error("failed to close capture session", "err", err)
}
}
s.monitorMu.Unlock()
return s.recordManager.StopAll(ctx)
}
91 changes: 85 additions & 6 deletions server/cmd/api/api/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@ package api

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"

"github.com/google/uuid"
"github.com/onkernel/kernel-images/server/lib/events"
"github.com/onkernel/kernel-images/server/lib/logger"
)

// StartCapture handles POST /events/start.
// Generates a new capture session ID, seeds the pipeline, then starts the
// CDP monitor. If already running, the monitor is stopped and
// restarted with a fresh session ID
// StartCapture handles POST /events/start. Restarts if already running.
func (s *ApiService) StartCapture(w http.ResponseWriter, r *http.Request) {
s.monitorMu.Lock()
defer s.monitorMu.Unlock()

s.captureSession.Start(uuid.New().String())
captureSessionID := uuid.New().String()
s.captureSession.Start(captureSessionID)

if err := s.cdpMonitor.Start(context.Background()); err != nil {
logger.FromContext(r.Context()).Error("failed to start CDP monitor", "err", err)
Expand All @@ -26,10 +29,86 @@ func (s *ApiService) StartCapture(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

// StopCapture handles POST /events/stop
// StopCapture handles POST /events/stop. No-op if not running.
func (s *ApiService) StopCapture(w http.ResponseWriter, r *http.Request) {
s.monitorMu.Lock()
defer s.monitorMu.Unlock()
s.cdpMonitor.Stop()
w.WriteHeader(http.StatusOK)
}

// PublishEvent handles POST /events/publish.
// Defaults Category (via CategoryFor) and Source.Kind (to KindKernelAPI) when omitted.
func (s *ApiService) PublishEvent(w http.ResponseWriter, r *http.Request) {
var ev events.Event
if err := json.NewDecoder(r.Body).Decode(&ev); err != nil {
http.Error(w, "invalid JSON body", http.StatusBadRequest)
return
}

if ev.Type == "" {
http.Error(w, "type is required", http.StatusBadRequest)
return
}

if ev.Category == "" {
ev.Category = events.CategoryFor(ev.Type)
}

if ev.Source.Kind == "" {
ev.Source.Kind = events.KindKernelAPI
}

s.captureSession.Publish(ev)
w.WriteHeader(http.StatusOK)
}

// StreamEvents handles GET /events/stream (SSE).
// Supports Last-Event-ID for reconnection.
func (s *ApiService) StreamEvents(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming not supported", http.StatusInternalServerError)
return
}

var lastSeq uint64
if v := r.Header.Get("Last-Event-ID"); v != "" {
if n, err := strconv.ParseUint(v, 10, 64); err == nil {
lastSeq = n
}
}

w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("X-Accel-Buffering", "no")
w.WriteHeader(http.StatusOK)
flusher.Flush()

reader := s.captureSession.NewReader(lastSeq)
ctx := r.Context()

for {
res, err := reader.Read(ctx)
if err != nil {
return
}
if res.Envelope == nil {
continue
}
if err := writeSSEEnvelope(w, *res.Envelope); err != nil {
return
}
flusher.Flush()
}
}

// writeSSEEnvelope writes a single SSE frame: "id: {seq}\ndata: {json}\n\n".
func writeSSEEnvelope(w io.Writer, env events.Envelope) error {
data, err := json.Marshal(env)
if err != nil {
return err
}
_, err = fmt.Fprintf(w, "id: %d\ndata: %s\n\n", env.Seq, data)
return err
}
150 changes: 150 additions & 0 deletions server/cmd/api/api/events_publish_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package api

import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"

"github.com/onkernel/kernel-images/server/lib/events"
"github.com/onkernel/kernel-images/server/lib/recorder"
"github.com/onkernel/kernel-images/server/lib/scaletozero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newPublishTestService(t *testing.T, logDir string) (*ApiService, *events.CaptureSession) {
t.Helper()
ring := events.NewRingBuffer(16)
fw := events.NewFileWriter(logDir)
cs := events.NewCaptureSession(ring, fw)
cs.Start("test-capture")
svc, err := New(
recorder.NewFFmpegManager(),
newMockFactory(),
newTestUpstreamManager(),
scaletozero.NewNoopController(),
newMockNekoClient(t),
cs,
0,
)
require.NoError(t, err)
return svc, cs
}

func publishEvent(t *testing.T, svc *ApiService, ev events.Event) *httptest.ResponseRecorder {
t.Helper()
b, err := json.Marshal(ev)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, "/events/publish", bytes.NewReader(b))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
svc.PublishEvent(w, req)
return w
}

func readEnvelope(t *testing.T, cs *events.CaptureSession) events.Envelope {
t.Helper()
reader := cs.NewReader(0)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
res, err := reader.Read(ctx)
require.NoError(t, err)
require.NotNil(t, res.Envelope)
return *res.Envelope
}

func assertLogFileExists(t *testing.T, logDir, filename string) {
t.Helper()
info, err := os.Stat(filepath.Join(logDir, filename))
require.NoError(t, err, "%s should exist", filename)
assert.Greater(t, info.Size(), int64(0), "%s should be non-empty", filename)
}

func TestPublishEvent(t *testing.T) {
t.Run("valid_event_published_to_ring", func(t *testing.T) {
logDir := t.TempDir()
svc, cs := newPublishTestService(t, logDir)

w := publishEvent(t, svc, events.Event{
Type: "liveview_click",
Category: events.CategoryLiveview,
Source: events.Source{Kind: events.KindKernelAPI},
Data: json.RawMessage(`{"x":100}`),
})
assert.Equal(t, http.StatusOK, w.Code)

env := readEnvelope(t, cs)
assert.Equal(t, "liveview_click", env.Event.Type)
assert.Equal(t, events.CategoryLiveview, env.Event.Category)
})

t.Run("invalid_json", func(t *testing.T) {
svc, _ := newPublishTestService(t, t.TempDir())

req := httptest.NewRequest(http.MethodPost, "/events/publish", bytes.NewReader([]byte(`not-json`)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
svc.PublishEvent(w, req)

assert.Equal(t, http.StatusBadRequest, w.Code)
})

t.Run("empty_type_rejected", func(t *testing.T) {
svc, _ := newPublishTestService(t, t.TempDir())

w := publishEvent(t, svc, events.Event{
Category: events.CategoryConsole,
Data: json.RawMessage(`{"x":1}`),
})
assert.Equal(t, http.StatusBadRequest, w.Code)
})

t.Run("liveview_routes_to_log", func(t *testing.T) {
logDir := t.TempDir()
svc, _ := newPublishTestService(t, logDir)

w := publishEvent(t, svc, events.Event{
Type: "liveview_click",
Category: events.CategoryLiveview,
Source: events.Source{Kind: events.KindKernelAPI},
Data: json.RawMessage(`{"x":100}`),
})
require.Equal(t, http.StatusOK, w.Code)
assertLogFileExists(t, logDir, "liveview.log")
})

t.Run("captcha_routes_to_log", func(t *testing.T) {
logDir := t.TempDir()
svc, _ := newPublishTestService(t, logDir)

w := publishEvent(t, svc, events.Event{
Type: "captcha_solve",
Category: events.CategoryCaptcha,
Source: events.Source{Kind: events.KindKernelAPI},
Data: json.RawMessage(`{"token":"abc"}`),
})
require.Equal(t, http.StatusOK, w.Code)
assertLogFileExists(t, logDir, "captcha.log")
})

t.Run("category_derived_from_type", func(t *testing.T) {
logDir := t.TempDir()
svc, cs := newPublishTestService(t, logDir)

w := publishEvent(t, svc, events.Event{
Type: "liveview_click",
Data: json.RawMessage(`{"x":50}`),
})
require.Equal(t, http.StatusOK, w.Code)

env := readEnvelope(t, cs)
assert.Equal(t, events.CategoryLiveview, env.Event.Category)
assertLogFileExists(t, logDir, "liveview.log")
})
}
Loading