diff --git a/server/cmd/api/api/api.go b/server/cmd/api/api/api.go index 3523904d..57689c80 100644 --- a/server/cmd/api/api/api.go +++ b/server/cmd/api/api/api.go @@ -41,6 +41,11 @@ type ApiService struct { upstreamMgr *devtoolsproxy.UpstreamManager stz scaletozero.Controller + // CDP event pipeline and cdpMonitor. + captureSession *events.CaptureSession + cdpMonitor *cdpmonitor.Monitor + monitorMu sync.Mutex + // inputMu serializes input-related operations (mouse, keyboard, screenshot) inputMu sync.Mutex @@ -70,11 +75,6 @@ type ApiService struct { // xvfbResizeMu serializes background Xvfb restarts to prevent races // when multiple CDP fast-path resizes fire in quick succession. xvfbResizeMu sync.Mutex - - // CDP event pipeline and cdpMonitor. - captureSession *events.CaptureSession - cdpMonitor *cdpmonitor.Monitor - monitorMu sync.Mutex } var _ oapi.StrictServerInterface = (*ApiService)(nil) @@ -101,8 +101,6 @@ func New( return nil, fmt.Errorf("captureSession cannot be nil") } - mon := cdpmonitor.New(upstreamMgr, captureSession.Publish, displayNum) - return &ApiService{ recordManager: recordManager, factory: factory, @@ -114,7 +112,7 @@ func New( nekoAuthClient: nekoAuthClient, policy: &policy.Policy{}, captureSession: captureSession, - cdpMonitor: mon, + cdpMonitor: cdpmonitor.New(upstreamMgr, captureSession.Publish, displayNum), }, nil } @@ -335,8 +333,14 @@ func (s *ApiService) ListRecorders(ctx context.Context, _ oapi.ListRecordersRequ func (s *ApiService) Shutdown(ctx context.Context) error { s.monitorMu.Lock() - s.cdpMonitor.Stop() - _ = s.captureSession.Close() + if s.cdpMonitor != nil { + s.cdpMonitor.Stop() + } + if s.captureSession != nil { + if err := s.captureSession.Close(); err != nil { + logger.FromContext(ctx).Error("failed to close capture session", "err", err) + } + } s.monitorMu.Unlock() return s.recordManager.StopAll(ctx) } diff --git a/server/cmd/api/api/events.go b/server/cmd/api/api/events.go index f9021a17..9c2435f8 100644 --- a/server/cmd/api/api/events.go +++ b/server/cmd/api/api/events.go @@ -2,21 +2,24 @@ package api import ( "context" + "encoding/json" + "fmt" + "io" "net/http" + "strconv" "github.com/google/uuid" + "github.com/onkernel/kernel-images/server/lib/events" "github.com/onkernel/kernel-images/server/lib/logger" ) -// StartCapture handles POST /events/start. -// Generates a new capture session ID, seeds the pipeline, then starts the -// CDP monitor. If already running, the monitor is stopped and -// restarted with a fresh session ID +// StartCapture handles POST /events/start. Restarts if already running. func (s *ApiService) StartCapture(w http.ResponseWriter, r *http.Request) { s.monitorMu.Lock() defer s.monitorMu.Unlock() - s.captureSession.Start(uuid.New().String()) + captureSessionID := uuid.New().String() + s.captureSession.Start(captureSessionID) if err := s.cdpMonitor.Start(context.Background()); err != nil { logger.FromContext(r.Context()).Error("failed to start CDP monitor", "err", err) @@ -26,10 +29,86 @@ func (s *ApiService) StartCapture(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -// StopCapture handles POST /events/stop +// StopCapture handles POST /events/stop. No-op if not running. func (s *ApiService) StopCapture(w http.ResponseWriter, r *http.Request) { s.monitorMu.Lock() defer s.monitorMu.Unlock() s.cdpMonitor.Stop() w.WriteHeader(http.StatusOK) } + +// PublishEvent handles POST /events/publish. +// Defaults Category (via CategoryFor) and Source.Kind (to KindKernelAPI) when omitted. +func (s *ApiService) PublishEvent(w http.ResponseWriter, r *http.Request) { + var ev events.Event + if err := json.NewDecoder(r.Body).Decode(&ev); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + + if ev.Type == "" { + http.Error(w, "type is required", http.StatusBadRequest) + return + } + + if ev.Category == "" { + ev.Category = events.CategoryFor(ev.Type) + } + + if ev.Source.Kind == "" { + ev.Source.Kind = events.KindKernelAPI + } + + s.captureSession.Publish(ev) + w.WriteHeader(http.StatusOK) +} + +// StreamEvents handles GET /events/stream (SSE). +// Supports Last-Event-ID for reconnection. +func (s *ApiService) StreamEvents(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming not supported", http.StatusInternalServerError) + return + } + + var lastSeq uint64 + if v := r.Header.Get("Last-Event-ID"); v != "" { + if n, err := strconv.ParseUint(v, 10, 64); err == nil { + lastSeq = n + } + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("X-Accel-Buffering", "no") + w.WriteHeader(http.StatusOK) + flusher.Flush() + + reader := s.captureSession.NewReader(lastSeq) + ctx := r.Context() + + for { + res, err := reader.Read(ctx) + if err != nil { + return + } + if res.Envelope == nil { + continue + } + if err := writeSSEEnvelope(w, *res.Envelope); err != nil { + return + } + flusher.Flush() + } +} + +// writeSSEEnvelope writes a single SSE frame: "id: {seq}\ndata: {json}\n\n". +func writeSSEEnvelope(w io.Writer, env events.Envelope) error { + data, err := json.Marshal(env) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "id: %d\ndata: %s\n\n", env.Seq, data) + return err +} diff --git a/server/cmd/api/api/events_publish_test.go b/server/cmd/api/api/events_publish_test.go new file mode 100644 index 00000000..4e1e4099 --- /dev/null +++ b/server/cmd/api/api/events_publish_test.go @@ -0,0 +1,150 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/onkernel/kernel-images/server/lib/events" + "github.com/onkernel/kernel-images/server/lib/recorder" + "github.com/onkernel/kernel-images/server/lib/scaletozero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newPublishTestService(t *testing.T, logDir string) (*ApiService, *events.CaptureSession) { + t.Helper() + ring := events.NewRingBuffer(16) + fw := events.NewFileWriter(logDir) + cs := events.NewCaptureSession(ring, fw) + cs.Start("test-capture") + svc, err := New( + recorder.NewFFmpegManager(), + newMockFactory(), + newTestUpstreamManager(), + scaletozero.NewNoopController(), + newMockNekoClient(t), + cs, + 0, + ) + require.NoError(t, err) + return svc, cs +} + +func publishEvent(t *testing.T, svc *ApiService, ev events.Event) *httptest.ResponseRecorder { + t.Helper() + b, err := json.Marshal(ev) + require.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/events/publish", bytes.NewReader(b)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + svc.PublishEvent(w, req) + return w +} + +func readEnvelope(t *testing.T, cs *events.CaptureSession) events.Envelope { + t.Helper() + reader := cs.NewReader(0) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + res, err := reader.Read(ctx) + require.NoError(t, err) + require.NotNil(t, res.Envelope) + return *res.Envelope +} + +func assertLogFileExists(t *testing.T, logDir, filename string) { + t.Helper() + info, err := os.Stat(filepath.Join(logDir, filename)) + require.NoError(t, err, "%s should exist", filename) + assert.Greater(t, info.Size(), int64(0), "%s should be non-empty", filename) +} + +func TestPublishEvent(t *testing.T) { + t.Run("valid_event_published_to_ring", func(t *testing.T) { + logDir := t.TempDir() + svc, cs := newPublishTestService(t, logDir) + + w := publishEvent(t, svc, events.Event{ + Type: "liveview_click", + Category: events.CategoryLiveview, + Source: events.Source{Kind: events.KindKernelAPI}, + Data: json.RawMessage(`{"x":100}`), + }) + assert.Equal(t, http.StatusOK, w.Code) + + env := readEnvelope(t, cs) + assert.Equal(t, "liveview_click", env.Event.Type) + assert.Equal(t, events.CategoryLiveview, env.Event.Category) + }) + + t.Run("invalid_json", func(t *testing.T) { + svc, _ := newPublishTestService(t, t.TempDir()) + + req := httptest.NewRequest(http.MethodPost, "/events/publish", bytes.NewReader([]byte(`not-json`))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + svc.PublishEvent(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("empty_type_rejected", func(t *testing.T) { + svc, _ := newPublishTestService(t, t.TempDir()) + + w := publishEvent(t, svc, events.Event{ + Category: events.CategoryConsole, + Data: json.RawMessage(`{"x":1}`), + }) + assert.Equal(t, http.StatusBadRequest, w.Code) + }) + + t.Run("liveview_routes_to_log", func(t *testing.T) { + logDir := t.TempDir() + svc, _ := newPublishTestService(t, logDir) + + w := publishEvent(t, svc, events.Event{ + Type: "liveview_click", + Category: events.CategoryLiveview, + Source: events.Source{Kind: events.KindKernelAPI}, + Data: json.RawMessage(`{"x":100}`), + }) + require.Equal(t, http.StatusOK, w.Code) + assertLogFileExists(t, logDir, "liveview.log") + }) + + t.Run("captcha_routes_to_log", func(t *testing.T) { + logDir := t.TempDir() + svc, _ := newPublishTestService(t, logDir) + + w := publishEvent(t, svc, events.Event{ + Type: "captcha_solve", + Category: events.CategoryCaptcha, + Source: events.Source{Kind: events.KindKernelAPI}, + Data: json.RawMessage(`{"token":"abc"}`), + }) + require.Equal(t, http.StatusOK, w.Code) + assertLogFileExists(t, logDir, "captcha.log") + }) + + t.Run("category_derived_from_type", func(t *testing.T) { + logDir := t.TempDir() + svc, cs := newPublishTestService(t, logDir) + + w := publishEvent(t, svc, events.Event{ + Type: "liveview_click", + Data: json.RawMessage(`{"x":50}`), + }) + require.Equal(t, http.StatusOK, w.Code) + + env := readEnvelope(t, cs) + assert.Equal(t, events.CategoryLiveview, env.Event.Category) + assertLogFileExists(t, logDir, "liveview.log") + }) +} diff --git a/server/cmd/api/api/events_stream_test.go b/server/cmd/api/api/events_stream_test.go new file mode 100644 index 00000000..fe4e893e --- /dev/null +++ b/server/cmd/api/api/events_stream_test.go @@ -0,0 +1,105 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/onkernel/kernel-images/server/lib/events" + "github.com/stretchr/testify/assert" +) + +var testEvent = events.Event{ + Type: "console_log", + Category: events.CategoryConsole, + Source: events.Source{Kind: events.KindCDP}, +} + +func streamRequest(ctx context.Context) (*httptest.ResponseRecorder, *http.Request) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/events/stream", nil).WithContext(ctx) + return w, req +} + +func TestStreamEvents(t *testing.T) { + t.Run("delivers_buffered_events", func(t *testing.T) { + svc, cs := newPublishTestService(t, t.TempDir()) + cs.Publish(testEvent) + cs.Publish(testEvent) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + w, req := streamRequest(ctx) + + svc.StreamEvents(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "text/event-stream", w.Header().Get("Content-Type")) + body := w.Body.String() + assert.Contains(t, body, "id: 1") + assert.Contains(t, body, "id: 2") + }) + + t.Run("resumes_after_last_event_id", func(t *testing.T) { + svc, cs := newPublishTestService(t, t.TempDir()) + cs.Publish(testEvent) + cs.Publish(testEvent) + cs.Publish(testEvent) + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + w, req := streamRequest(ctx) + req.Header.Set("Last-Event-ID", "2") + + svc.StreamEvents(w, req) + + body := w.Body.String() + assert.Contains(t, body, "id: 3") + assert.NotContains(t, body, "id: 1") + assert.NotContains(t, body, "id: 2") + }) + + t.Run("exits_on_cancelled_context", func(t *testing.T) { + svc, _ := newPublishTestService(t, t.TempDir()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + w, req := streamRequest(ctx) + + done := make(chan struct{}) + go func() { + svc.StreamEvents(w, req) + close(done) + }() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + t.Error("StreamEvents did not return after context cancellation") + } + }) + + t.Run("rejects_non_flusher", func(t *testing.T) { + svc, _ := newPublishTestService(t, t.TempDir()) + + req := httptest.NewRequest(http.MethodGet, "/events/stream", nil) + w := &nonFlusherWriter{header: make(http.Header)} + + svc.StreamEvents(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.code) + }) +} + +type nonFlusherWriter struct { + header http.Header + code int + body strings.Builder +} + +func (w *nonFlusherWriter) Header() http.Header { return w.header } +func (w *nonFlusherWriter) WriteHeader(code int) { w.code = code } +func (w *nonFlusherWriter) Write(b []byte) (int, error) { return w.body.Write(b) } diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index 767c4881..50b5636b 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -128,10 +128,6 @@ func main() { w.Header().Set("Content-Type", "application/json") w.Write(jsonData) }) - // capture events - r.Post("/events/start", apiService.StartCapture) - r.Post("/events/stop", apiService.StopCapture) - // PTY attach endpoint (WebSocket) - not part of OpenAPI spec // Uses WebSocket for bidirectional streaming, which works well through proxies. r.Get("/process/{process_id}/attach", func(w http.ResponseWriter, r *http.Request) { @@ -139,6 +135,12 @@ func main() { apiService.HandleProcessAttachWS(w, r, id) }) + // Events capture lifecycle (not part of OpenAPI spec — simple internal control endpoints) + r.Post("/events/start", apiService.StartCapture) + r.Post("/events/stop", apiService.StopCapture) + r.Post("/events/publish", apiService.PublishEvent) + r.Get("/events/stream", apiService.StreamEvents) + // Serve extension files for Chrome policy-installed extensions // This allows Chrome to download .crx and update.xml files via HTTP extensionsDir := "/home/kernel/extensions"