From c5e122a13076a94acedc78dd81a3107a5da4f995 Mon Sep 17 00:00:00 2001 From: Marcus Messer Date: Wed, 6 May 2026 11:01:18 +0100 Subject: [PATCH 1/4] Added `MuEdHandler` to handle `/evaluate` and `/evaluate/health` endpoints with authentication and runtime integration, along with associated tests --- handler/module.go | 3 + handler/mued.go | 172 +++++++++++++++++++++++ handler/mued_test.go | 325 +++++++++++++++++++++++++++++++++++++++++++ handler/routes.go | 8 ++ runtime/mued.go | 161 +++++++++++++++++++++ runtime/mued_test.go | 299 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 968 insertions(+) create mode 100644 handler/mued.go create mode 100644 handler/mued_test.go create mode 100644 runtime/mued.go create mode 100644 runtime/mued_test.go diff --git a/handler/module.go b/handler/module.go index ac7b0da..cea7455 100644 --- a/handler/module.go +++ b/handler/module.go @@ -5,8 +5,11 @@ import "go.uber.org/fx" func Module() fx.Option { return fx.Module("common", fx.Provide(NewCommandHandler), + fx.Provide(NewMuEdHandler), fx.Provide(NewLegacyRoute), fx.Provide(NewCommandRoute), fx.Provide(NewHealthRoute), + fx.Provide(NewMuEdEvaluateRoute), + fx.Provide(NewMuEdHealthRoute), ) } diff --git a/handler/mued.go b/handler/mued.go new file mode 100644 index 0000000..8bfbb95 --- /dev/null +++ b/handler/mued.go @@ -0,0 +1,172 @@ +package handler + +import ( + "encoding/json" + "io" + "net/http" + + "go.uber.org/fx" + "go.uber.org/zap" + + "github.com/lambda-feedback/shimmy/config" + "github.com/lambda-feedback/shimmy/runtime" +) + +type MuEdHandlerParams struct { + fx.In + + Handler runtime.Handler + Runtime runtime.Runtime + Config config.Config + Log *zap.Logger +} + +type MuEdHandler struct { + handler runtime.Handler + runtime runtime.Runtime + config config.Config + log *zap.Logger +} + +func NewMuEdHandler(params MuEdHandlerParams) *MuEdHandler { + return &MuEdHandler{ + handler: params.Handler, + runtime: params.Runtime, + config: params.Config, + log: params.Log, + } +} + +func (h *MuEdHandler) checkAuth(w http.ResponseWriter, r *http.Request) bool { + if h.config.Auth.Key != "" && r.Header.Get("api-key") != h.config.Auth.Key { + h.log.Debug("unauthorized request", zap.String("path", r.URL.Path)) + http.Error(w, "unauthorized", http.StatusUnauthorized) + return false + } + return true +} + +// ServeEvaluate handles POST /evaluate. +func (h *MuEdHandler) ServeEvaluate(w http.ResponseWriter, r *http.Request) { + if !h.checkAuth(w, r) { + return + } + + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "failed to read body", http.StatusBadRequest) + return + } + + var muEdReq runtime.MuEdEvaluateRequest + if err := json.Unmarshal(body, &muEdReq); err != nil { + http.Error(w, "invalid request body", http.StatusBadRequest) + return + } + + isPreview := muEdReq.PreSubmissionFeedback != nil && muEdReq.PreSubmissionFeedback.Enabled + + var legacyBody map[string]any + if isPreview { + legacyBody, err = runtime.MuEdBuildLegacyPreviewRequest(muEdReq) + } else { + legacyBody, err = runtime.MuEdBuildLegacyEvalRequest(muEdReq) + } + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + legacyBodyBytes, err := json.Marshal(legacyBody) + if err != nil { + http.Error(w, "failed to build request", http.StatusInternalServerError) + return + } + + command := runtime.CommandEvaluate + if isPreview { + command = runtime.CommandPreview + } + + header := http.Header{} + header.Set("Command", string(command)) + + req := runtime.Request{ + Path: r.URL.Path, + Method: http.MethodPost, + Body: legacyBodyBytes, + Header: header, + } + + resp := h.handler.Handle(r.Context(), req) + + if resp.StatusCode != http.StatusOK { + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + w.WriteHeader(resp.StatusCode) + w.Write(resp.Body) //nolint:errcheck + return + } + + var respBody map[string]any + if err := json.Unmarshal(resp.Body, &respBody); err != nil { + http.Error(w, "failed to parse response", http.StatusInternalServerError) + return + } + + result, ok := respBody["result"].(map[string]any) + if !ok { + http.Error(w, "invalid response from evaluation function", http.StatusInternalServerError) + return + } + + var feedback []map[string]any + if isPreview { + feedback = runtime.MuEdToPreviewFeedback(result) + } else { + feedback = runtime.MuEdToEvalFeedback(result) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(feedback) //nolint:errcheck +} + +// ServeHealth handles GET /evaluate/health. +func (h *MuEdHandler) ServeHealth(w http.ResponseWriter, r *http.Request) { + if !h.checkAuth(w, r) { + return + } + + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + resp, err := h.runtime.Handle(r.Context(), runtime.EvaluationRequest{ + Command: runtime.CommandHealth, + Data: map[string]any{}, + }) + if err != nil { + http.Error(w, "health check failed", http.StatusInternalServerError) + return + } + + result, ok := resp["result"] + if !ok { + http.Error(w, "invalid health response", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(result) //nolint:errcheck +} diff --git a/handler/mued_test.go b/handler/mued_test.go new file mode 100644 index 0000000..bcc905f --- /dev/null +++ b/handler/mued_test.go @@ -0,0 +1,325 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/lambda-feedback/shimmy/config" + "github.com/lambda-feedback/shimmy/runtime" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +// --- Mock runtime --- + +type MockRuntime struct { + mock.Mock +} + +func (m *MockRuntime) Handle(ctx context.Context, req runtime.EvaluationRequest) (runtime.EvaluationResponse, error) { + args := m.Called(ctx, req) + return args.Get(0).(runtime.EvaluationResponse), args.Error(1) +} + +func (m *MockRuntime) Start(ctx context.Context) error { + return m.Called(ctx).Error(0) +} + +func (m *MockRuntime) Shutdown(ctx context.Context) error { + return m.Called(ctx).Error(0) +} + +// --- Helpers --- + +func newMuEdHandler(h runtime.Handler, r runtime.Runtime, key string) *MuEdHandler { + return &MuEdHandler{ + handler: h, + runtime: r, + config: config.Config{Auth: config.AuthConfig{Key: key}}, + log: zap.NewNop(), + } +} + +func mathEvalBody(t *testing.T) []byte { + t.Helper() + b, err := json.Marshal(map[string]any{ + "submission": map[string]any{ + "type": "MATH", + "content": map[string]any{"expression": "x^2"}, + }, + "task": map[string]any{ + "referenceSolution": map[string]any{ + "type": "MATH", + "content": map[string]any{"expression": "x^2"}, + }, + }, + }) + require.NoError(t, err) + return b +} + +func evalHandlerResponse(isCorrect bool, feedback string) runtime.Response { + body, _ := json.Marshal(map[string]any{ + "command": "eval", + "result": map[string]any{ + "is_correct": isCorrect, + "feedback": feedback, + }, + }) + return runtime.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: body, + } +} + +// --- ServeEvaluate tests --- + +func TestMuEdServeEvaluate_Success(t *testing.T) { + mockHandler := new(MockHandler) + mockHandler.On("Handle", mock.Anything, mock.Anything). + Return(evalHandlerResponse(true, "Well done")) + + req := httptest.NewRequest(http.MethodPost, "/evaluate", bytes.NewReader(mathEvalBody(t))) + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "").ServeEvaluate(w, req) + + res := w.Result() + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, "application/json", res.Header.Get("Content-Type")) + + var feedback []map[string]any + require.NoError(t, json.Unmarshal(body, &feedback)) + require.Len(t, feedback, 1) + assert.Equal(t, 1.0, feedback[0]["awardedPoints"]) + assert.Equal(t, "Well done", feedback[0]["message"]) + assert.Contains(t, string(body), `"responseLatex":null`) + assert.Contains(t, string(body), `"responseSimplified":null`) + + mockHandler.AssertExpectations(t) +} + +func TestMuEdServeEvaluate_LegacyBodyForwarded(t *testing.T) { + mockHandler := new(MockHandler) + mockHandler.On("Handle", mock.Anything, mock.MatchedBy(func(r runtime.Request) bool { + var body map[string]any + if err := json.Unmarshal(r.Body, &body); err != nil { + return false + } + return body["response"] == "x^2" && + body["answer"] == "x^2" && + r.Header.Get("Command") == "eval" + })).Return(evalHandlerResponse(true, "Correct")) + + req := httptest.NewRequest(http.MethodPost, "/evaluate", bytes.NewReader(mathEvalBody(t))) + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "").ServeEvaluate(w, req) + + assert.Equal(t, http.StatusOK, w.Result().StatusCode) + mockHandler.AssertExpectations(t) +} + +func TestMuEdServeEvaluate_Preview(t *testing.T) { + previewBody, _ := json.Marshal(map[string]any{ + "submission": map[string]any{ + "type": "MATH", + "content": map[string]any{"expression": "x^2"}, + }, + "preSubmissionFeedback": map[string]any{"enabled": true}, + }) + + previewResult := map[string]any{"preview": map[string]any{"latex": "x^{2}"}} + respBody, _ := json.Marshal(map[string]any{ + "command": "preview", + "result": previewResult, + }) + mockHandler := new(MockHandler) + mockHandler.On("Handle", mock.Anything, mock.MatchedBy(func(r runtime.Request) bool { + return r.Header.Get("Command") == "preview" + })).Return(runtime.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: respBody, + }) + + req := httptest.NewRequest(http.MethodPost, "/evaluate", bytes.NewReader(previewBody)) + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "").ServeEvaluate(w, req) + + res := w.Result() + defer res.Body.Close() + raw, _ := io.ReadAll(res.Body) + + assert.Equal(t, http.StatusOK, res.StatusCode) + + var feedback []map[string]any + require.NoError(t, json.Unmarshal(raw, &feedback)) + require.Len(t, feedback, 1) + assert.NotNil(t, feedback[0]["preSubmissionFeedback"]) + + mockHandler.AssertExpectations(t) +} + +func TestMuEdServeEvaluate_Unauthorized(t *testing.T) { + mockHandler := new(MockHandler) + + req := httptest.NewRequest(http.MethodPost, "/evaluate", bytes.NewReader(mathEvalBody(t))) + req.Header.Set("api-key", "wrong") + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "secret").ServeEvaluate(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Result().StatusCode) + mockHandler.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything) +} + +func TestMuEdServeEvaluate_MethodNotAllowed(t *testing.T) { + mockHandler := new(MockHandler) + + req := httptest.NewRequest(http.MethodGet, "/evaluate", nil) + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "").ServeEvaluate(w, req) + + assert.Equal(t, http.StatusMethodNotAllowed, w.Result().StatusCode) + mockHandler.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything) +} + +func TestMuEdServeEvaluate_InvalidJSON(t *testing.T) { + mockHandler := new(MockHandler) + + req := httptest.NewRequest(http.MethodPost, "/evaluate", bytes.NewReader([]byte("not json"))) + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "").ServeEvaluate(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Result().StatusCode) + mockHandler.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything) +} + +func TestMuEdServeEvaluate_MissingReferenceSolution(t *testing.T) { + mockHandler := new(MockHandler) + + body, _ := json.Marshal(map[string]any{ + "submission": map[string]any{ + "type": "MATH", + "content": map[string]any{"expression": "x^2"}, + }, + }) + req := httptest.NewRequest(http.MethodPost, "/evaluate", bytes.NewReader(body)) + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "").ServeEvaluate(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Result().StatusCode) + mockHandler.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything) +} + +func TestMuEdServeEvaluate_WorkerErrorForwarded(t *testing.T) { + errorBody, _ := json.Marshal(map[string]any{ + "error": map[string]any{"message": "evaluation failed"}, + }) + mockHandler := new(MockHandler) + mockHandler.On("Handle", mock.Anything, mock.Anything).Return(runtime.Response{ + StatusCode: http.StatusInternalServerError, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: errorBody, + }) + + req := httptest.NewRequest(http.MethodPost, "/evaluate", bytes.NewReader(mathEvalBody(t))) + w := httptest.NewRecorder() + + newMuEdHandler(mockHandler, nil, "").ServeEvaluate(w, req) + + res := w.Result() + defer res.Body.Close() + raw, _ := io.ReadAll(res.Body) + + assert.Equal(t, http.StatusInternalServerError, res.StatusCode) + assert.Equal(t, errorBody, bytes.TrimRight(raw, "\n")) +} + +// --- ServeHealth tests --- + +func TestMuEdServeHealth_Success(t *testing.T) { + healthResult := map[string]any{"tests_passed": true, "successes": []any{}, "failures": []any{}, "errors": []any{}} + mockRuntime := new(MockRuntime) + mockRuntime.On("Handle", mock.Anything, runtime.EvaluationRequest{ + Command: runtime.CommandHealth, + Data: map[string]any{}, + }).Return(runtime.EvaluationResponse{ + "command": "healthcheck", + "result": healthResult, + }, nil) + + req := httptest.NewRequest(http.MethodGet, "/evaluate/health", nil) + w := httptest.NewRecorder() + + newMuEdHandler(nil, mockRuntime, "").ServeHealth(w, req) + + res := w.Result() + defer res.Body.Close() + raw, _ := io.ReadAll(res.Body) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, "application/json", res.Header.Get("Content-Type")) + + var result map[string]any + require.NoError(t, json.Unmarshal(raw, &result)) + assert.Equal(t, true, result["tests_passed"]) + + mockRuntime.AssertExpectations(t) +} + +func TestMuEdServeHealth_Unauthorized(t *testing.T) { + mockRuntime := new(MockRuntime) + + req := httptest.NewRequest(http.MethodGet, "/evaluate/health", nil) + req.Header.Set("api-key", "wrong") + w := httptest.NewRecorder() + + newMuEdHandler(nil, mockRuntime, "secret").ServeHealth(w, req) + + assert.Equal(t, http.StatusUnauthorized, w.Result().StatusCode) + mockRuntime.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything) +} + +func TestMuEdServeHealth_MethodNotAllowed(t *testing.T) { + mockRuntime := new(MockRuntime) + + req := httptest.NewRequest(http.MethodPost, "/evaluate/health", nil) + w := httptest.NewRecorder() + + newMuEdHandler(nil, mockRuntime, "").ServeHealth(w, req) + + assert.Equal(t, http.StatusMethodNotAllowed, w.Result().StatusCode) + mockRuntime.AssertNotCalled(t, "Handle", mock.Anything, mock.Anything) +} + +func TestMuEdServeHealth_RuntimeError(t *testing.T) { + mockRuntime := new(MockRuntime) + mockRuntime.On("Handle", mock.Anything, mock.Anything). + Return(runtime.EvaluationResponse{}, errors.New("worker unavailable")) + + req := httptest.NewRequest(http.MethodGet, "/evaluate/health", nil) + w := httptest.NewRecorder() + + newMuEdHandler(nil, mockRuntime, "").ServeHealth(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Result().StatusCode) + mockRuntime.AssertExpectations(t) +} diff --git a/handler/routes.go b/handler/routes.go index b447ecd..e5ac936 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -17,3 +17,11 @@ func NewCommandRoute(handler *CommandHandler) server.HttpHandlerResult { func NewHealthRoute() server.HttpHandlerResult { return server.AsHttpHandler("/health", http.HandlerFunc(HealthHandler)) } + +func NewMuEdEvaluateRoute(handler *MuEdHandler) server.HttpHandlerResult { + return server.AsHttpHandler("POST /evaluate", http.HandlerFunc(handler.ServeEvaluate)) +} + +func NewMuEdHealthRoute(handler *MuEdHandler) server.HttpHandlerResult { + return server.AsHttpHandler("GET /evaluate/health", http.HandlerFunc(handler.ServeHealth)) +} diff --git a/runtime/mued.go b/runtime/mued.go new file mode 100644 index 0000000..636c93f --- /dev/null +++ b/runtime/mued.go @@ -0,0 +1,161 @@ +package runtime + +import "fmt" + +type MuEdSubmissionType string + +const ( + MuEdMath MuEdSubmissionType = "MATH" + MuEdText MuEdSubmissionType = "TEXT" + MuEdCode MuEdSubmissionType = "CODE" + MuEdModel MuEdSubmissionType = "MODEL" + MuEdOther MuEdSubmissionType = "OTHER" +) + +type MuEdSubmission struct { + Type MuEdSubmissionType `json:"type"` + Content map[string]any `json:"content"` +} + +type MuEdTask struct { + ReferenceSolution *MuEdSubmission `json:"referenceSolution"` +} + +type MuEdConfiguration struct { + Params map[string]any `json:"params"` +} + +type MuEdPreSubmissionFeedback struct { + Enabled bool `json:"enabled"` +} + +type MuEdEvaluateRequest struct { + Submission MuEdSubmission `json:"submission"` + Task *MuEdTask `json:"task"` + Configuration *MuEdConfiguration `json:"configuration"` + PreSubmissionFeedback *MuEdPreSubmissionFeedback `json:"preSubmissionFeedback"` +} + +func muEdContentKey(t MuEdSubmissionType) string { + switch t { + case MuEdMath: + return "expression" + case MuEdText: + return "text" + case MuEdCode: + return "code" + case MuEdModel: + return "model" + default: + return "value" + } +} + +func muEdExtractContent(content map[string]any, t MuEdSubmissionType) (any, error) { + key := muEdContentKey(t) + if val, ok := content[key]; ok { + return val, nil + } + if t != MuEdOther { + if val, ok := content["value"]; ok { + return val, nil + } + } + return nil, fmt.Errorf("could not extract content for submission type %s", t) +} + +func muEdExtractParams(req MuEdEvaluateRequest) map[string]any { + if req.Configuration != nil && req.Configuration.Params != nil { + return req.Configuration.Params + } + return map[string]any{} +} + +// MuEdBuildLegacyEvalRequest builds {response, answer, params} for the eval command. +func MuEdBuildLegacyEvalRequest(req MuEdEvaluateRequest) (map[string]any, error) { + response, err := muEdExtractContent(req.Submission.Content, req.Submission.Type) + if err != nil { + return nil, fmt.Errorf("submission: %w", err) + } + + if req.Task == nil || req.Task.ReferenceSolution == nil { + return nil, fmt.Errorf("task.referenceSolution is required for evaluation") + } + + sol := req.Task.ReferenceSolution + answer, err := muEdExtractContent(sol.Content, sol.Type) + if err != nil { + return nil, fmt.Errorf("referenceSolution: %w", err) + } + + return map[string]any{ + "response": response, + "answer": answer, + "params": muEdExtractParams(req), + }, nil +} + +// MuEdBuildLegacyPreviewRequest builds {response, params} for the preview command. +func MuEdBuildLegacyPreviewRequest(req MuEdEvaluateRequest) (map[string]any, error) { + response, err := muEdExtractContent(req.Submission.Content, req.Submission.Type) + if err != nil { + return nil, fmt.Errorf("submission: %w", err) + } + + return map[string]any{ + "response": response, + "params": muEdExtractParams(req), + }, nil +} + +// MuEdToEvalFeedback transforms a legacy result map into a muEd Feedback array. +// responseLatex and responseSimplified are always present in the output (null when absent). +func MuEdToEvalFeedback(result map[string]any) []map[string]any { + feedback := map[string]any{ + "responseLatex": nil, + "responseSimplified": nil, + } + + if isCorrect, ok := result["is_correct"].(bool); ok { + pts := 0.0 + if isCorrect { + pts = 1.0 + } + feedback["awardedPoints"] = pts + } + + if msg, ok := result["feedback"].(string); ok { + feedback["message"] = msg + } + + if mc, ok := result["matched_case"].(float64); ok { + feedback["matchedCase"] = int(mc) + } + + if rl, ok := result["response_latex"].(string); ok { + feedback["responseLatex"] = rl + } + + if rs, ok := result["response_simplified"].(string); ok { + feedback["responseSimplified"] = rs + } + + if tags, ok := result["tags"].([]any); ok { + strTags := make([]string, 0, len(tags)) + for _, t := range tags { + if s, ok := t.(string); ok { + strTags = append(strTags, s) + } + } + feedback["tags"] = strTags + } + + return []map[string]any{feedback} +} + +// MuEdToPreviewFeedback wraps a legacy preview result as [{"preSubmissionFeedback": result}]. +func MuEdToPreviewFeedback(result map[string]any) []map[string]any { + return []map[string]any{ + {"preSubmissionFeedback": result}, + } +} diff --git a/runtime/mued_test.go b/runtime/mued_test.go new file mode 100644 index 0000000..15bd7b9 --- /dev/null +++ b/runtime/mued_test.go @@ -0,0 +1,299 @@ +package runtime_test + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/lambda-feedback/shimmy/runtime" +) + +func TestMuEdContentKey(t *testing.T) { + cases := []struct { + t runtime.MuEdSubmissionType + want string + }{ + {runtime.MuEdMath, "expression"}, + {runtime.MuEdText, "text"}, + {runtime.MuEdCode, "code"}, + {runtime.MuEdModel, "model"}, + {runtime.MuEdOther, "value"}, + {"UNKNOWN", "value"}, + } + + for _, tc := range cases { + t.Run(string(tc.t), func(t *testing.T) { + // contentKey is unexported; exercise it via BuildLegacyEvalRequest + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: tc.t, + Content: map[string]any{tc.want: "x"}, + }, + Task: &runtime.MuEdTask{ + ReferenceSolution: &runtime.MuEdSubmission{ + Type: tc.t, + Content: map[string]any{tc.want: "x"}, + }, + }, + } + body, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.NoError(t, err) + assert.Equal(t, "x", body["response"]) + }) + } +} + +func TestMuEdBuildLegacyEvalRequest(t *testing.T) { + t.Run("MATH primary key", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x^2"}, + }, + Task: &runtime.MuEdTask{ + ReferenceSolution: &runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x^2"}, + }, + }, + } + body, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.NoError(t, err) + assert.Equal(t, "x^2", body["response"]) + assert.Equal(t, "x^2", body["answer"]) + assert.Equal(t, map[string]any{}, body["params"]) + }) + + t.Run("TEXT primary key", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdText, + Content: map[string]any{"text": "hello"}, + }, + Task: &runtime.MuEdTask{ + ReferenceSolution: &runtime.MuEdSubmission{ + Type: runtime.MuEdText, + Content: map[string]any{"text": "hello"}, + }, + }, + } + body, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.NoError(t, err) + assert.Equal(t, "hello", body["response"]) + assert.Equal(t, "hello", body["answer"]) + }) + + t.Run("OTHER type uses value key", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdOther, + Content: map[string]any{"value": "foo"}, + }, + Task: &runtime.MuEdTask{ + ReferenceSolution: &runtime.MuEdSubmission{ + Type: runtime.MuEdOther, + Content: map[string]any{"value": "bar"}, + }, + }, + } + body, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.NoError(t, err) + assert.Equal(t, "foo", body["response"]) + assert.Equal(t, "bar", body["answer"]) + }) + + t.Run("missing primary key falls back to value", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"value": "x^2"}, + }, + Task: &runtime.MuEdTask{ + ReferenceSolution: &runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"value": "x^2"}, + }, + }, + } + body, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.NoError(t, err) + assert.Equal(t, "x^2", body["response"]) + assert.Equal(t, "x^2", body["answer"]) + }) + + t.Run("missing both keys returns error", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"unrelated": "x"}, + }, + Task: &runtime.MuEdTask{ + ReferenceSolution: &runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x"}, + }, + }, + } + _, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.Error(t, err) + }) + + t.Run("nil task returns error", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x^2"}, + }, + Task: nil, + } + _, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.Error(t, err) + }) + + t.Run("nil reference solution returns error", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x^2"}, + }, + Task: &runtime.MuEdTask{ReferenceSolution: nil}, + } + _, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.Error(t, err) + }) + + t.Run("params forwarded from configuration", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x"}, + }, + Task: &runtime.MuEdTask{ + ReferenceSolution: &runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x"}, + }, + }, + Configuration: &runtime.MuEdConfiguration{ + Params: map[string]any{"strict": true}, + }, + } + body, err := runtime.MuEdBuildLegacyEvalRequest(req) + require.NoError(t, err) + assert.Equal(t, map[string]any{"strict": true}, body["params"]) + }) +} + +func TestMuEdBuildLegacyPreviewRequest(t *testing.T) { + t.Run("extracts response only", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdMath, + Content: map[string]any{"expression": "x^2"}, + }, + } + body, err := runtime.MuEdBuildLegacyPreviewRequest(req) + require.NoError(t, err) + assert.Equal(t, "x^2", body["response"]) + _, hasAnswer := body["answer"] + assert.False(t, hasAnswer) + }) + + t.Run("params forwarded", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdText, + Content: map[string]any{"text": "hi"}, + }, + Configuration: &runtime.MuEdConfiguration{ + Params: map[string]any{"lang": "en"}, + }, + } + body, err := runtime.MuEdBuildLegacyPreviewRequest(req) + require.NoError(t, err) + assert.Equal(t, map[string]any{"lang": "en"}, body["params"]) + }) + + t.Run("nil configuration gives empty params", func(t *testing.T) { + req := runtime.MuEdEvaluateRequest{ + Submission: runtime.MuEdSubmission{ + Type: runtime.MuEdCode, + Content: map[string]any{"code": "print()"}, + }, + Configuration: nil, + } + body, err := runtime.MuEdBuildLegacyPreviewRequest(req) + require.NoError(t, err) + assert.Equal(t, map[string]any{}, body["params"]) + }) +} + +func TestMuEdToEvalFeedback(t *testing.T) { + t.Run("is_correct true gives awardedPoints 1", func(t *testing.T) { + result := map[string]any{"is_correct": true, "feedback": "Well done"} + fb := runtime.MuEdToEvalFeedback(result) + require.Len(t, fb, 1) + assert.Equal(t, 1.0, fb[0]["awardedPoints"]) + assert.Equal(t, "Well done", fb[0]["message"]) + }) + + t.Run("is_correct false gives awardedPoints 0", func(t *testing.T) { + result := map[string]any{"is_correct": false, "feedback": "Try again"} + fb := runtime.MuEdToEvalFeedback(result) + require.Len(t, fb, 1) + assert.Equal(t, 0.0, fb[0]["awardedPoints"]) + }) + + t.Run("matched_case mapped to matchedCase int", func(t *testing.T) { + result := map[string]any{"is_correct": false, "matched_case": float64(2)} + fb := runtime.MuEdToEvalFeedback(result) + assert.Equal(t, 2, fb[0]["matchedCase"]) + }) + + t.Run("responseLatex present", func(t *testing.T) { + result := map[string]any{"is_correct": true, "response_latex": "x^{2}"} + fb := runtime.MuEdToEvalFeedback(result) + assert.Equal(t, "x^{2}", fb[0]["responseLatex"]) + }) + + t.Run("responseLatex absent is null in JSON", func(t *testing.T) { + result := map[string]any{"is_correct": true} + fb := runtime.MuEdToEvalFeedback(result) + assert.Nil(t, fb[0]["responseLatex"]) + + raw, err := json.Marshal(fb) + require.NoError(t, err) + assert.Contains(t, string(raw), `"responseLatex":null`) + }) + + t.Run("responseSimplified present", func(t *testing.T) { + result := map[string]any{"is_correct": true, "response_simplified": "x^2"} + fb := runtime.MuEdToEvalFeedback(result) + assert.Equal(t, "x^2", fb[0]["responseSimplified"]) + }) + + t.Run("responseSimplified absent is null in JSON", func(t *testing.T) { + result := map[string]any{"is_correct": true} + fb := runtime.MuEdToEvalFeedback(result) + assert.Nil(t, fb[0]["responseSimplified"]) + + raw, err := json.Marshal(fb) + require.NoError(t, err) + assert.Contains(t, string(raw), `"responseSimplified":null`) + }) + + t.Run("tags mapped", func(t *testing.T) { + result := map[string]any{"is_correct": true, "tags": []any{"algebra", "calculus"}} + fb := runtime.MuEdToEvalFeedback(result) + assert.Equal(t, []string{"algebra", "calculus"}, fb[0]["tags"]) + }) +} + +func TestMuEdToPreviewFeedback(t *testing.T) { + result := map[string]any{"preview": map[string]any{"latex": "x^2"}} + fb := runtime.MuEdToPreviewFeedback(result) + require.Len(t, fb, 1) + assert.Equal(t, result, fb[0]["preSubmissionFeedback"]) +} From 18fcd3c66a3162070ded901e893ab492e20dee63 Mon Sep 17 00:00:00 2001 From: Marcus Messer Date: Wed, 6 May 2026 13:50:11 +0100 Subject: [PATCH 2/4] Added `workflow_dispatch` trigger to GitHub Actions build workflow --- .github/workflows/build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0330fdd..67ab210 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,6 +7,8 @@ on: branches: - main pull_request: + workflow_dispatch: + jobs: build: From 5ef5e025490c7df5c6f77e47bf69c33b13c263d9 Mon Sep 17 00:00:00 2001 From: Marcus Messer Date: Thu, 7 May 2026 08:52:20 +0100 Subject: [PATCH 3/4] Removed `NewCommandRoute` and corrected route definitions for `/evaluate` and `/evaluate/health` --- handler/module.go | 1 - handler/routes.go | 8 ++------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/handler/module.go b/handler/module.go index cea7455..a5002f7 100644 --- a/handler/module.go +++ b/handler/module.go @@ -7,7 +7,6 @@ func Module() fx.Option { fx.Provide(NewCommandHandler), fx.Provide(NewMuEdHandler), fx.Provide(NewLegacyRoute), - fx.Provide(NewCommandRoute), fx.Provide(NewHealthRoute), fx.Provide(NewMuEdEvaluateRoute), fx.Provide(NewMuEdHealthRoute), diff --git a/handler/routes.go b/handler/routes.go index e5ac936..75dbad1 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -10,18 +10,14 @@ func NewLegacyRoute(handler *CommandHandler) server.HttpHandlerResult { return server.AsHttpHandler("/", handler) } -func NewCommandRoute(handler *CommandHandler) server.HttpHandlerResult { - return server.AsHttpHandler("/{command}", handler) -} - func NewHealthRoute() server.HttpHandlerResult { return server.AsHttpHandler("/health", http.HandlerFunc(HealthHandler)) } func NewMuEdEvaluateRoute(handler *MuEdHandler) server.HttpHandlerResult { - return server.AsHttpHandler("POST /evaluate", http.HandlerFunc(handler.ServeEvaluate)) + return server.AsHttpHandler("/evaluate", http.HandlerFunc(handler.ServeEvaluate)) } func NewMuEdHealthRoute(handler *MuEdHandler) server.HttpHandlerResult { - return server.AsHttpHandler("GET /evaluate/health", http.HandlerFunc(handler.ServeHealth)) + return server.AsHttpHandler("/evaluate/health", http.HandlerFunc(handler.ServeHealth)) } From afcb66265cafe09b14e2497d406c980a5bf1fd2b Mon Sep 17 00:00:00 2001 From: Marcus Messer Date: Fri, 8 May 2026 11:27:58 +0100 Subject: [PATCH 4/4] Added `NormalizePath` middleware to canonicalize `/evaluate` and `/evaluate/health` paths across server and lambda integrations --- app/lambda/handler.go | 6 +++--- internal/server/middleware.go | 25 +++++++++++++++++++++++++ internal/server/server.go | 4 ++-- 3 files changed, 30 insertions(+), 5 deletions(-) create mode 100644 internal/server/middleware.go diff --git a/app/lambda/handler.go b/app/lambda/handler.go index 9deedf4..c230576 100644 --- a/app/lambda/handler.go +++ b/app/lambda/handler.go @@ -101,11 +101,11 @@ func (s *LambdaHandler) Shutdown() { func (s *LambdaHandler) getProxyFunction() (any, error) { switch s.config.ProxySource { case ProxySourceApiGatewayV1: - return httpadapter.New(s.mux).ProxyWithContext, nil + return httpadapter.New(server.NormalizePath(s.mux)).ProxyWithContext, nil case ProxySourceApiGatewayV2: - return httpadapter.NewV2(s.mux).ProxyWithContext, nil + return httpadapter.NewV2(server.NormalizePath(s.mux)).ProxyWithContext, nil case ProxySourceAlb: - return httpadapter.NewALB(s.mux).ProxyWithContext, nil + return httpadapter.NewALB(server.NormalizePath(s.mux)).ProxyWithContext, nil default: return nil, fmt.Errorf("invalid proxy source: %s", s.config.ProxySource) } diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..381ad48 --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,25 @@ +package server + +import ( + "net/http" + "strings" +) + +// NormalizePath rewrites request paths that end in /evaluate or /evaluate/health +// to their canonical forms before the mux routes them. This mirrors the Python +// BaseEvaluationFunctionLayer behaviour: API Gateway forwards the full +// function-name prefix (e.g. /compareSets/evaluate), but shimmy only registers +// /evaluate. +func NormalizePath(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/evaluate/health"): + r = r.Clone(r.Context()) + r.URL.Path = "/evaluate/health" + case strings.HasSuffix(r.URL.Path, "/evaluate"): + r = r.Clone(r.Context()) + r.URL.Path = "/evaluate" + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 6fe7316..883be2f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -38,9 +38,9 @@ func NewHttpServer(params HttpServerParams) *HttpServer { mux.Handle(handler.Name, handler.Handler) } - var handler http.Handler = mux + var handler http.Handler = NormalizePath(mux) if params.Config.H2c { - handler = h2c.NewHandler(mux, &http2.Server{}) + handler = h2c.NewHandler(NormalizePath(mux), &http2.Server{}) } server := &http.Server{