diff --git a/src/__tests__/buttonHandler.test.ts b/src/__tests__/buttonHandler.test.ts new file mode 100644 index 0000000..166f232 --- /dev/null +++ b/src/__tests__/buttonHandler.test.ts @@ -0,0 +1,106 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; + +const sessionManagerMock = vi.hoisted(() => ({ + getSessionForThread: vi.fn(), + listQuestions: vi.fn(), + replyQuestion: vi.fn(), + rejectQuestion: vi.fn(), + abortSession: vi.fn(), + ensureSessionForThread: vi.fn(), + sendPrompt: vi.fn(), +})); + +vi.mock("../services/sessionManager.js", () => sessionManagerMock); +vi.mock("../services/serveManager.js", () => ({ + getPort: vi.fn(), + spawnServe: vi.fn(), + waitForReady: vi.fn(), +})); +vi.mock("../services/dataStore.js", () => ({ + getChannelModel: vi.fn(), + getWorktreeMapping: vi.fn(), + removeWorktreeMapping: vi.fn(), +})); +vi.mock("../services/worktreeManager.js", () => ({ + worktreeExists: vi.fn(), + removeWorktree: vi.fn(), +})); + +import { handleButton } from "../handlers/buttonHandler.js"; + +function mockInteraction(customId: string) { + return { + customId, + reply: vi.fn(), + deferReply: vi.fn(), + editReply: vi.fn(), + channel: { id: "channel-1", isThread: () => false }, + } as any; +} + +describe("handleButton question responses", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("answers OpenCode questions with the selected option", async () => { + sessionManagerMock.getSessionForThread.mockReturnValue({ + sessionId: "ses_123", + projectPath: "/repo", + port: 14098, + }); + sessionManagerMock.listQuestions.mockResolvedValue([ + { + id: "que_dfcfdc0e70013EvGpyc0soVaR7", + sessionID: "ses_123", + questions: [ + { + question: "Approve this plan?", + options: [{ label: "Approve plan" }, { label: "Revise plan" }], + }, + ], + }, + ]); + sessionManagerMock.replyQuestion.mockResolvedValue(true); + + const interaction = mockInteraction( + "qanswer:thread123:que_dfcfdc0e70013EvGpyc0soVaR7:0", + ); + + await handleButton(interaction); + + expect(interaction.deferReply).toHaveBeenCalled(); + expect(sessionManagerMock.replyQuestion).toHaveBeenCalledWith( + 14098, + "que_dfcfdc0e70013EvGpyc0soVaR7", + [["Approve plan"]], + ); + expect(interaction.editReply).toHaveBeenCalledWith({ + content: "βœ… Sent response: Approve plan", + }); + }); + + it("rejects OpenCode questions", async () => { + sessionManagerMock.getSessionForThread.mockReturnValue({ + sessionId: "ses_123", + projectPath: "/repo", + port: 14098, + }); + sessionManagerMock.rejectQuestion.mockResolvedValue(true); + + const interaction = mockInteraction( + "qreject:thread123:que_dfcfdc0e70013EvGpyc0soVaR7", + ); + + await handleButton(interaction); + + expect(interaction.deferReply).toHaveBeenCalled(); + expect(sessionManagerMock.rejectQuestion).toHaveBeenCalledWith( + 14098, + "que_dfcfdc0e70013EvGpyc0soVaR7", + ); + expect(interaction.editReply).toHaveBeenCalledWith({ + content: "🚫 Question rejected.", + }); + }); +}); diff --git a/src/__tests__/sessionManager.test.ts b/src/__tests__/sessionManager.test.ts index 64e4f6a..ff6d52b 100644 --- a/src/__tests__/sessionManager.test.ts +++ b/src/__tests__/sessionManager.test.ts @@ -53,6 +53,9 @@ import { getSessionInfo, listSessions, abortSession, + listQuestions, + replyQuestion, + rejectQuestion, ensureSessionForThread, getSessionForThread, setSessionForThread, @@ -190,6 +193,54 @@ describe("SessionManager", () => { }); }); + describe("question helpers", () => { + it("should list pending questions", async () => { + const questions = [{ id: "que_123", sessionID: "ses_123", questions: [] }]; + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => questions, + }); + + await expect(listQuestions(3000)).resolves.toEqual(questions); + + expect(mockFetch).toHaveBeenCalledWith("http://127.0.0.1:3000/question", { + method: "GET", + headers: {}, + }); + }); + + it("should reply to a pending question", async () => { + mockFetch.mockResolvedValueOnce({ ok: true }); + + await expect( + replyQuestion(3000, "que_123", [["Approve plan"]]), + ).resolves.toBe(true); + + expect(mockFetch).toHaveBeenCalledWith( + "http://127.0.0.1:3000/question/que_123/reply", + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ answers: [["Approve plan"]] }), + }, + ); + }); + + it("should reject a pending question", async () => { + mockFetch.mockResolvedValueOnce({ ok: true }); + + await expect(rejectQuestion(3000, "que_123")).resolves.toBe(true); + + expect(mockFetch).toHaveBeenCalledWith( + "http://127.0.0.1:3000/question/que_123/reject", + { + method: "POST", + headers: {}, + }, + ); + }); + }); + describe("thread-session mapping", () => { it("should store and retrieve session for thread", () => { setSessionForThread("thread1", "ses_123", "/path/to/project", 4000); diff --git a/src/__tests__/sseClient.test.ts b/src/__tests__/sseClient.test.ts index 441f75e..7766e99 100644 --- a/src/__tests__/sseClient.test.ts +++ b/src/__tests__/sseClient.test.ts @@ -336,6 +336,60 @@ describe("SSEClient", () => { }); }); + describe("onQuestionAsked", () => { + it("should trigger callback for question.asked events", () => { + const callback = vi.fn(); + client.connect("http://127.0.0.1:3000"); + client.onQuestionAsked(callback); + + const messageHandler = + mockEventSourceInstance.addEventListener.mock.calls.find( + (call: any) => call[0] === "message", + )?.[1]; + + const request = { + id: "que_123", + sessionID: "session-1", + questions: [ + { + header: "Plan Approval", + question: "Approve this plan?", + options: [{ label: "Approve plan" }, { label: "Revise plan" }], + }, + ], + }; + + messageHandler({ + data: JSON.stringify({ + type: "question.asked", + properties: request, + }), + }); + + expect(callback).toHaveBeenCalledWith(request); + }); + + it("should not trigger callback for malformed question.asked events", () => { + const callback = vi.fn(); + client.connect("http://127.0.0.1:3000"); + client.onQuestionAsked(callback); + + const messageHandler = + mockEventSourceInstance.addEventListener.mock.calls.find( + (call: any) => call[0] === "message", + )?.[1]; + + messageHandler({ + data: JSON.stringify({ + type: "question.asked", + properties: { id: "que_123" }, + }), + }); + + expect(callback).not.toHaveBeenCalled(); + }); + }); + describe("onError", () => { it("should trigger callback on error", () => { const callback = vi.fn(); diff --git a/src/handlers/buttonHandler.ts b/src/handlers/buttonHandler.ts index d282d2b..cd508c1 100644 --- a/src/handlers/buttonHandler.ts +++ b/src/handlers/buttonHandler.ts @@ -1,4 +1,5 @@ import { ButtonInteraction, ThreadChannel, MessageFlags } from 'discord.js'; +import type { QuestionRequest } from '../types/index.js'; import * as sessionManager from '../services/sessionManager.js'; import * as serveManager from '../services/serveManager.js'; import * as dataStore from '../services/dataStore.js'; @@ -6,6 +7,18 @@ import * as worktreeManager from '../services/worktreeManager.js'; export async function handleButton(interaction: ButtonInteraction) { const customId = interaction.customId; + + if (customId.startsWith('qanswer:')) { + const [, threadId, requestId, optionIndexRaw] = customId.split(':'); + await handleQuestionAnswer(interaction, threadId, requestId, optionIndexRaw); + return; + } + + if (customId.startsWith('qreject:')) { + const [, threadId, requestId] = customId.split(':'); + await handleQuestionReject(interaction, threadId, requestId); + return; + } const [action, threadId] = customId.split('_'); @@ -67,6 +80,90 @@ async function handleInterrupt(interaction: ButtonInteraction, threadId: string) } } + +async function handleQuestionAnswer( + interaction: ButtonInteraction, + threadId: string | undefined, + requestId: string | undefined, + optionIndexRaw: string | undefined, +) { + const optionIndex = Number(optionIndexRaw); + + if (!threadId || !requestId || !Number.isInteger(optionIndex)) { + await interaction.reply({ + content: '❌ Invalid question response.', + flags: MessageFlags.Ephemeral, + }); + return; + } + + const session = sessionManager.getSessionForThread(threadId); + if (!session) { + await interaction.reply({ + content: '⚠️ Session not found.', + flags: MessageFlags.Ephemeral, + }); + return; + } + + await interaction.deferReply({ flags: MessageFlags.Ephemeral }); + + try { + const questions = (await sessionManager.listQuestions(session.port)) as QuestionRequest[]; + const request = questions.find((q) => q.id === requestId); + const question = request?.questions?.[0]; + const option = question?.options?.[optionIndex]; + + if (!option?.label) { + await interaction.editReply({ + content: '⚠️ Pending question/option not found. It may have already been answered.', + }); + return; + } + + await sessionManager.replyQuestion(session.port, requestId, [[option.label]]); + await interaction.editReply({ content: `βœ… Sent response: ${option.label}` }); + } catch (error) { + await interaction.editReply({ + content: `❌ Failed to answer question: ${(error as Error).message}`, + }); + } +} + +async function handleQuestionReject( + interaction: ButtonInteraction, + threadId: string | undefined, + requestId: string | undefined, +) { + if (!threadId || !requestId) { + await interaction.reply({ + content: '❌ Invalid question rejection.', + flags: MessageFlags.Ephemeral, + }); + return; + } + + const session = sessionManager.getSessionForThread(threadId); + if (!session) { + await interaction.reply({ + content: '⚠️ Session not found.', + flags: MessageFlags.Ephemeral, + }); + return; + } + + await interaction.deferReply({ flags: MessageFlags.Ephemeral }); + + try { + await sessionManager.rejectQuestion(session.port, requestId); + await interaction.editReply({ content: '🚫 Question rejected.' }); + } catch (error) { + await interaction.editReply({ + content: `❌ Failed to reject question: ${(error as Error).message}`, + }); + } +} + async function handleWorktreeDelete(interaction: ButtonInteraction, threadId: string) { const mapping = dataStore.getWorktreeMapping(threadId); if (!mapping) { diff --git a/src/services/executionService.ts b/src/services/executionService.ts index 25d9c41..70f4033 100644 --- a/src/services/executionService.ts +++ b/src/services/executionService.ts @@ -13,6 +13,7 @@ import * as worktreeManager from './worktreeManager.js'; import { SSEClient } from './sseClient.js'; import { formatOutput, formatOutputForMobile, buildContextHeader } from '../utils/messageFormatter.js'; import { processNextInQueue } from './queueManager.js'; +import type { QuestionRequest } from '../types/index.js'; export async function runPrompt( channel: TextBasedChannel, @@ -267,6 +268,48 @@ export async function runPrompt( })(); }); + sseClient.onQuestionAsked((request: QuestionRequest) => { + if (request.sessionID !== sessionId) return; + + if (updateInterval) { + clearInterval(updateInterval); + updateInterval = null; + } + + (async () => { + try { + const question = request.questions?.[0]; + const header = question?.header ? `**${question.header}**` : '**OpenCode needs input**'; + const body = question?.question ?? 'OpenCode is waiting for a response.'; + const optionButtons = (question?.options ?? []).slice(0, 4).map((option, index) => + new ButtonBuilder() + .setCustomId(`qanswer:${threadId}:${request.id}:${index}`) + .setLabel((option.label ?? `Option ${index + 1}`).slice(0, 80)) + .setStyle(index === 0 ? ButtonStyle.Primary : ButtonStyle.Secondary) + ); + const rejectButton = new ButtonBuilder() + .setCustomId(`qreject:${threadId}:${request.id}`) + .setLabel('Reject') + .setStyle(ButtonStyle.Danger); + const questionButtons = new ActionRowBuilder().addComponents( + ...optionButtons, + rejectButton, + ); + + const edited = await updateStreamMessage( + `${contextHeader}\nπŸ“Œ **Prompt**: ${prompt}\n\n⏸️ **Waiting for OpenCode input**\n${header}\n\n${body.slice(0, 1500)}`, + [questionButtons], + ); + if (!edited) { + await safeSend(`⏸️ OpenCode is waiting for input: ${header}`); + } + } catch (error) { + console.error('Error in onQuestionAsked:', error); + await safeSend('❌ OpenCode asked a question, but I could not render it in Discord.'); + } + })(); + }); + sseClient.onError((error) => { if (updateInterval) { clearInterval(updateInterval); diff --git a/src/services/sessionManager.ts b/src/services/sessionManager.ts index 31be14a..26c01a4 100644 --- a/src/services/sessionManager.ts +++ b/src/services/sessionManager.ts @@ -181,6 +181,67 @@ export async function abortSession( return response.ok; } + +export async function listQuestions(port: number): Promise { + const url = `http://127.0.0.1:${port}/question`; + const response = await fetch(url, { + method: "GET", + headers: getAuthHeaders(), + }); + + if (!response.ok) { + assertNotAuthError(response.status, "Failed to list questions"); + throw new Error(`Failed to list questions: ${response.status} ${response.statusText}`); + } + + const data = await response.json(); + return Array.isArray(data) ? data : []; +} + +export async function replyQuestion( + port: number, + requestId: string, + answers: string[][], +): Promise { + const url = `http://127.0.0.1:${port}/question/${requestId}/reply`; + const response = await fetch(url, { + method: "POST", + headers: jsonHeaders(), + body: JSON.stringify({ answers }), + }); + + if (!response.ok) { + const responseBody = await response.text(); + assertNotAuthError(response.status, "Failed to answer question"); + throw new Error( + `Failed to answer question: ${response.status} ${response.statusText} β€” ${responseBody}`, + ); + } + + return true; +} + +export async function rejectQuestion( + port: number, + requestId: string, +): Promise { + const url = `http://127.0.0.1:${port}/question/${requestId}/reject`; + const response = await fetch(url, { + method: "POST", + headers: getAuthHeaders(), + }); + + if (!response.ok) { + const responseBody = await response.text(); + assertNotAuthError(response.status, "Failed to reject question"); + throw new Error( + `Failed to reject question: ${response.status} ${response.statusText} β€” ${responseBody}`, + ); + } + + return true; +} + export function getSessionForThread( threadId: string, ): { sessionId: string; projectPath: string; port: number } | undefined { diff --git a/src/services/sseClient.ts b/src/services/sseClient.ts index 43682ae..f05191e 100644 --- a/src/services/sseClient.ts +++ b/src/services/sseClient.ts @@ -1,5 +1,5 @@ import { EventSource } from "eventsource"; -import type { TextPart, SSEEvent, SessionErrorInfo } from "../types/index.js"; +import type { TextPart, SSEEvent, SessionErrorInfo, QuestionRequest } from "../types/index.js"; import { getAuthHeaders } from "./serverAuth.js"; type PartUpdatedCallback = (part: TextPart) => void; @@ -8,6 +8,7 @@ type SessionErrorCallback = ( sessionId: string, error: SessionErrorInfo, ) => void; +type QuestionAskedCallback = (request: QuestionRequest) => void; type ErrorCallback = (error: Error) => void; export class SSEClient { @@ -15,6 +16,7 @@ export class SSEClient { private partUpdatedCallbacks: PartUpdatedCallback[] = []; private sessionIdleCallbacks: SessionIdleCallback[] = []; private sessionErrorCallbacks: SessionErrorCallback[] = []; + private questionAskedCallbacks: QuestionAskedCallback[] = []; private errorCallbacks: ErrorCallback[] = []; connect(baseUrl: string): void { @@ -64,6 +66,10 @@ export class SSEClient { this.sessionErrorCallbacks.push(callback); } + onQuestionAsked(callback: QuestionAskedCallback): void { + this.questionAskedCallbacks.push(callback); + } + onError(callback: ErrorCallback): void { this.errorCallbacks.push(callback); } @@ -107,6 +113,11 @@ export class SSEClient { if (sessionID && error) { this.sessionErrorCallbacks.forEach((cb) => cb(sessionID, error)); } + } else if (event.type === "question.asked") { + const request = event.properties as unknown as QuestionRequest; + if (request?.sessionID && request?.id) { + this.questionAskedCallbacks.forEach((cb) => cb(request)); + } } } diff --git a/src/types/index.ts b/src/types/index.ts index f6ae651..c03407b 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -50,6 +50,29 @@ export interface SSEEvent { properties: Record; } +export interface QuestionOption { + label: string; + description?: string; +} + +export interface QuestionItem { + question: string; + header?: string; + options?: QuestionOption[]; + multiple?: boolean; + custom?: boolean; +} + +export interface QuestionRequest { + id: string; + sessionID: string; + questions: QuestionItem[]; + tool?: { + messageID: string; + callID: string; + }; +} + export interface ServeInstance { port: number; process: ChildProcess;