diff --git a/apps/mobile_chat_app/lib/features/chat/chat_history_api_service.dart b/apps/mobile_chat_app/lib/features/chat/chat_history_api_service.dart index eb67446..19426b5 100644 --- a/apps/mobile_chat_app/lib/features/chat/chat_history_api_service.dart +++ b/apps/mobile_chat_app/lib/features/chat/chat_history_api_service.dart @@ -65,6 +65,18 @@ class ChatAcceptedTask { final String acceptedAt; } +class ChatForkResult { + const ChatForkResult({ + required this.threadId, + required this.channelId, + required this.forkedSessionId, + }); + + final String threadId; + final String channelId; + final String forkedSessionId; +} + class ChatChannelNameSetting { const ChatChannelNameSetting({ required this.channelId, @@ -134,6 +146,7 @@ class ChatHistoryApiService { Uri get _scopesUri => Uri.parse('$_base/api/chat/scopes'); Uri get _scopeSettingsUri => Uri.parse('$_base/api/chat/scope-settings'); Uri get _channelNamesUri => Uri.parse('$_base/api/chat/channel-names'); + Uri get _forkUri => Uri.parse('$_base/api/chat/fork'); ChatTaskState? _parseTaskState(Object? value) { if (value is! String || value.isEmpty) return null; @@ -521,6 +534,32 @@ class ChatHistoryApiService { .toList(growable: false); } + Future forkThread({ + required String parentSessionId, + required String forkMessageId, + required String newThreadId, + }) async { + final response = await _apiClient.post( + _forkUri, + headers: {'Content-Type': 'application/json'}, + body: jsonEncode({ + 'parentSessionId': parentSessionId, + 'forkMessageId': forkMessageId, + 'newThreadId': newThreadId, + }), + ); + if (response.statusCode != 200) { + throw Exception('Failed to fork thread (${response.statusCode})'); + } + final raw = jsonDecode(response.body); + if (raw is! Map) throw Exception('Invalid fork response'); + return ChatForkResult( + threadId: raw['threadId'] as String, + channelId: raw['channelId'] as String, + forkedSessionId: raw['forkedSessionId'] as String, + ); + } + Future upsertMessages({ required List messages, }) async { diff --git a/apps/mobile_chat_app/lib/features/chat/chat_screen.dart b/apps/mobile_chat_app/lib/features/chat/chat_screen.dart index 6f2a704..bdbd3b9 100644 --- a/apps/mobile_chat_app/lib/features/chat/chat_screen.dart +++ b/apps/mobile_chat_app/lib/features/chat/chat_screen.dart @@ -46,6 +46,7 @@ class ChatScreen extends StatefulWidget { class _ChatScreenState extends State { final List _messages = []; + final Set _archivedMessageIds = {}; bool _isSending = false; bool _isStreaming = false; bool _loadingAgents = true; @@ -847,6 +848,7 @@ class _ChatScreenState extends State { _activeChannelId = id; _activeSubSection = 'main'; _messages.clear(); + _archivedMessageIds.clear(); _latestCheckpointCursor = null; _lastSyncedSeq = 0; }); @@ -1055,6 +1057,7 @@ class _ChatScreenState extends State { _activeChannelId = resolvedChannelId; _activeSubSection = restoredSubSection; _messages.clear(); + _archivedMessageIds.clear(); _latestCheckpointCursor = null; _lastSyncedSeq = 0; }); @@ -1133,6 +1136,7 @@ class _ChatScreenState extends State { if (!mounted || _isScopeStale()) return; setState(() { _messages.clear(); + _archivedMessageIds.clear(); _highlights = const {}; _textHighlights = const []; _latestCheckpointCursor = null; @@ -1228,6 +1232,89 @@ class _ChatScreenState extends State { } } + void _handleArchiveRound(ChatMessage message) { + final messageId = message.messageId; + if (messageId == null) return; + // Find the user message that immediately precedes this assistant message. + String? precedingUserMessageId; + bool foundTarget = false; + for (int i = _messages.length - 1; i >= 0; i--) { + final msg = _messages[i]; + if (!foundTarget) { + if (msg.messageId == messageId) foundTarget = true; + } else if (msg.role == 'user') { + precedingUserMessageId = msg.messageId; + break; + } + } + setState(() { + _archivedMessageIds.add(messageId); + if (precedingUserMessageId != null) { + _archivedMessageIds.add(precedingUserMessageId); + } + }); + } + + void _handleArchiveReply(ChatMessage message) { + final messageId = message.messageId; + if (messageId == null) return; + setState(() { + _archivedMessageIds.add(messageId); + }); + } + + Future _handleFork(ChatMessage message) async { + final messageId = message.messageId; + if (messageId == null) return; + final parentSessionId = _sessionIdForScope; + final newThreadId = _newId('fork'); + try { + await _chatHistoryApiService.forkThread( + parentSessionId: parentSessionId, + forkMessageId: messageId, + newThreadId: newThreadId, + ); + } catch (e) { + if (!mounted) return; + ScaffoldMessenger.of(context).showSnackBar( + SnackBar(content: Text('Fork failed: $e')), + ); + return; + } + if (!mounted) return; + final section = ChatSubSection( + id: newThreadId, + parentChannelId: _activeChannelId, + name: _timestampName(prefix: 'fork'), + createdAt: DateTime.now(), + ); + setState(() { + final items = _channelSubSections.putIfAbsent( + _activeChannelId, + () => [], + ); + items.add(section); + _activeSubSection = newThreadId; + _lastActiveSubSectionByChannel[_activeChannelId] = newThreadId; + _messages.clear(); + _archivedMessageIds.clear(); + _latestCheckpointCursor = null; + _lastSyncedSeq = 0; + }); + _configureActiveScopeSync(); + } + + Future _handleBranch(ChatMessage message) async { + // Branch forks from the user message itself — the new thread inherits + // context up to and including this user message from the parent. + await _handleFork(message); + } + + void _handleResend(ChatMessage message) { + if (message.content.trim().isEmpty) return; + _sendMessage(message.content); + } + String _subSectionKey(String channelId, String sectionId) => '$channelId::$sectionId'; @@ -2221,6 +2308,7 @@ class _ChatScreenState extends State { _activeSubSection = id; _lastActiveSubSectionByChannel[_activeChannelId] = id; _messages.clear(); + _archivedMessageIds.clear(); _latestCheckpointCursor = null; _lastSyncedSeq = 0; }); @@ -2298,6 +2386,7 @@ class _ChatScreenState extends State { _lastActiveSubSectionByChannel[channelId] = 'main'; _activeSubSection = 'main'; _messages.clear(); + _archivedMessageIds.clear(); _latestCheckpointCursor = null; _lastSyncedSeq = 0; }); @@ -2324,6 +2413,7 @@ class _ChatScreenState extends State { _activeSubSection = subSectionId; _lastActiveSubSectionByChannel[_activeChannelId] = subSectionId; _messages.clear(); + _archivedMessageIds.clear(); _latestCheckpointCursor = null; _lastSyncedSeq = 0; }); @@ -2950,10 +3040,21 @@ class _ChatScreenState extends State { children: [ Expanded( child: MessageList( - messages: _messages, + messages: _archivedMessageIds.isEmpty + ? _messages + : _messages + .where((m) => + m.messageId == null || + !_archivedMessageIds.contains(m.messageId)) + .toList(), highlights: _highlights, onHighlight: _handleHighlight, onDeleteHighlight: _handleDeleteHighlight, + onArchiveRound: _handleArchiveRound, + onArchiveReply: _handleArchiveReply, + onFork: _handleFork, + onBranch: _handleBranch, + onResend: _handleResend, ), ), Builder( diff --git a/apps/mobile_chat_app/lib/features/chat/widgets/message_list.dart b/apps/mobile_chat_app/lib/features/chat/widgets/message_list.dart index 8aff630..9602647 100644 --- a/apps/mobile_chat_app/lib/features/chat/widgets/message_list.dart +++ b/apps/mobile_chat_app/lib/features/chat/widgets/message_list.dart @@ -46,7 +46,9 @@ class MessageList extends StatefulWidget { this.onDeleteHighlight, this.onArchiveRound, this.onArchiveReply, - this.onMoveToThread, + this.onFork, + this.onBranch, + this.onResend, }); final List messages; @@ -66,14 +68,20 @@ class MessageList extends StatefulWidget { /// Called when the user taps Remove highlight in the floating highlight menu. final void Function(String highlightId)? onDeleteHighlight; - /// Called when the user selects "归档此轮" from the assistant message menu. + /// Called when the user selects "Archive Round" from the assistant message menu. final void Function(ChatMessage message)? onArchiveRound; - /// Called when the user selects "归档此回复" from the assistant message menu. + /// Called when the user selects "Archive Reply" from the assistant message menu. final void Function(ChatMessage message)? onArchiveReply; - /// Called when the user selects "移入Thread" from the assistant message menu. - final void Function(ChatMessage message)? onMoveToThread; + /// Called when the user selects "Fork" from the assistant message menu. + final void Function(ChatMessage message)? onFork; + + /// Called when the user selects "Branch" from the user message context menu. + final void Function(ChatMessage message)? onBranch; + + /// Called when the user selects "Resend" from the user message context menu. + final void Function(ChatMessage message)? onResend; @override State createState() => _MessageListState(); @@ -439,10 +447,10 @@ class _MessageListState extends State { ).showSnackBar(const SnackBar(content: Text('Copied'))); break; case 'branch': + widget.onBranch?.call(message); + break; case 'resend': - ScaffoldMessenger.of( - context, - ).showSnackBar(const SnackBar(content: Text('Coming soon'))); + widget.onResend?.call(message); break; } } @@ -454,8 +462,8 @@ class _MessageListState extends State { }) async { final hasArchiveRound = widget.onArchiveRound != null; final hasArchiveReply = widget.onArchiveReply != null; - final hasMoveToThread = widget.onMoveToThread != null; - if (!hasArchiveRound && !hasArchiveReply && !hasMoveToThread) return; + final hasFork = widget.onFork != null; + if (!hasArchiveRound && !hasArchiveReply && !hasFork) return; final overlay = Overlay.of(context).context.findRenderObject() as RenderBox; final result = await showGeneralDialog( context: context, @@ -468,7 +476,7 @@ class _MessageListState extends State { screenSize: overlay.size, showArchiveRound: hasArchiveRound, showArchiveReply: hasArchiveReply, - showMoveToThread: hasMoveToThread, + showFork: hasFork, ), ); if (!context.mounted || result == null) return; @@ -479,8 +487,8 @@ class _MessageListState extends State { case 'archive_reply': widget.onArchiveReply?.call(message); break; - case 'move_to_thread': - widget.onMoveToThread?.call(message); + case 'fork': + widget.onFork?.call(message); break; } } @@ -2670,8 +2678,8 @@ class _UserMessageContextMenu extends StatelessWidget { crossAxisAlignment: CrossAxisAlignment.stretch, children: [ _MenuItem(label: 'Copy', value: 'copy'), - _MenuItem(label: 'Branch (coming soon)', value: 'branch'), - _MenuItem(label: 'Resend (coming soon)', value: 'resend'), + _MenuItem(label: 'Branch', value: 'branch'), + _MenuItem(label: 'Resend', value: 'resend'), Container( padding: const EdgeInsets.symmetric( horizontal: 16, @@ -2739,14 +2747,14 @@ class _AssistantMessageActionMenu extends StatelessWidget { required this.screenSize, required this.showArchiveRound, required this.showArchiveReply, - required this.showMoveToThread, + required this.showFork, }); final Offset position; final Size screenSize; final bool showArchiveRound; final bool showArchiveReply; - final bool showMoveToThread; + final bool showFork; static const double _menuWidth = 200.0; static const double _itemHeight = 48.0; @@ -2755,9 +2763,9 @@ class _AssistantMessageActionMenu extends StatelessWidget { @override Widget build(BuildContext context) { final items = [ - if (showArchiveRound) const _MenuItem(label: '归档此轮', value: 'archive_round'), - if (showArchiveReply) const _MenuItem(label: '归档此回复', value: 'archive_reply'), - if (showMoveToThread) const _MenuItem(label: '移入Thread', value: 'move_to_thread'), + if (showArchiveRound) const _MenuItem(label: 'Archive Round', value: 'archive_round'), + if (showArchiveReply) const _MenuItem(label: 'Archive Reply', value: 'archive_reply'), + if (showFork) const _MenuItem(label: 'Fork', value: 'fork'), ]; final menuHeight = _itemHeight * items.length; diff --git a/apps/mobile_chat_app/test/message_list_test.dart b/apps/mobile_chat_app/test/message_list_test.dart index 1cc57e2..22082e7 100644 --- a/apps/mobile_chat_app/test/message_list_test.dart +++ b/apps/mobile_chat_app/test/message_list_test.dart @@ -40,7 +40,7 @@ Widget _build( void Function(String)? onDeleteHighlight, void Function(ChatMessage)? onArchiveRound, void Function(ChatMessage)? onArchiveReply, - void Function(ChatMessage)? onMoveToThread, + void Function(ChatMessage)? onFork, }) => MaterialApp( theme: theme, @@ -54,7 +54,7 @@ Widget _build( onDeleteHighlight: onDeleteHighlight, onArchiveRound: onArchiveRound, onArchiveReply: onArchiveReply, - onMoveToThread: onMoveToThread, + onFork: onFork, ), ), ), @@ -1637,13 +1637,13 @@ void main() { expect(received?.messageId, 'a-action'); }); - testWidgets('tapping move_to_thread calls onMoveToThread with message', + testWidgets('tapping fork calls onFork with message', (tester) async { ChatMessage? received; await tester.pumpWidget( _build( [_assistantMsg()], - onMoveToThread: (m) => received = m, + onFork: (m) => received = m, ), ); await tester.pumpAndSettle(); @@ -1651,8 +1651,8 @@ void main() { await tester.tap(find.byIcon(Icons.more_horiz)); await tester.pumpAndSettle(); - expect(find.text('移入Thread'), findsOneWidget); - await tester.tap(find.text('移入Thread')); + expect(find.text('Fork'), findsOneWidget); + await tester.tap(find.text('Fork')); await tester.pumpAndSettle(); expect(received?.messageId, 'a-action'); diff --git a/apps/node_backend/src/db/migrations/023_create_chat_thread_forks.sql b/apps/node_backend/src/db/migrations/023_create_chat_thread_forks.sql new file mode 100644 index 0000000..ea2fd45 --- /dev/null +++ b/apps/node_backend/src/db/migrations/023_create_chat_thread_forks.sql @@ -0,0 +1,22 @@ +-- Migration: Create chat_thread_forks table +-- Description: Tracks thread fork relationships so that forked threads can +-- inherit context from a parent thread up to a specific message. +-- When assembling LLM context, messages from the parent session up to +-- (and including) the fork point are prepended before the forked session's +-- own messages. +-- Version: 023 +-- Date: 2026-06-03 + +CREATE TABLE IF NOT EXISTS chat_thread_forks ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + forked_session_id VARCHAR(255) NOT NULL, + parent_session_id VARCHAR(255) NOT NULL, + fork_message_id VARCHAR(255) NOT NULL, + fork_write_seq BIGINT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_id, forked_session_id) +); + +CREATE INDEX IF NOT EXISTS idx_chat_thread_forks_forked_session + ON chat_thread_forks(user_id, forked_session_id); diff --git a/apps/node_backend/src/routes/chat.ts b/apps/node_backend/src/routes/chat.ts index 1a6dbdc..b92544e 100644 --- a/apps/node_backend/src/routes/chat.ts +++ b/apps/node_backend/src/routes/chat.ts @@ -3,14 +3,17 @@ import rateLimit from "express-rate-limit"; import { authenticate, AuthRequest } from "../middleware/auth.js"; import { acceptTask, + forkThread, listUserScopes, listSessionMessagesForModel, syncMessages, upsertMessages, type AcceptTaskInput, + type ForkThreadInput, type MessageUpsertInput, } from "../services/chatAsyncTransportService.js"; import { + buildChatSessionId, builtinDefaultNodeRef, CHAT_ROUTER_LOCAL, CHAT_ROUTER_PLUGIN, @@ -1832,6 +1835,60 @@ router.put("/channel-names", async (req: AuthRequest, res: Response) => { } }); +router.post("/fork", async (req: AuthRequest, res: Response) => { + try { + const userId = req.userId; + if (!userId) { + res.status(401).json({ error: "Unauthorized" }); + return; + } + + const body = req.body ?? {}; + const parentSessionId = parseSessionId(body.parentSessionId); + const forkMessageId = parseSessionId(body.forkMessageId); + const newThreadId = parseSessionId(body.newThreadId); + + if (!parentSessionId || !forkMessageId || !newThreadId) { + res.status(400).json({ + error: "parentSessionId, forkMessageId, and newThreadId are required", + }); + return; + } + + // Derive channelId from parentSessionId (format: session:channelId:threadId) + const parts = parentSessionId.split(':'); + if (parts.length < 3 || parts[0] !== 'session') { + res.status(400).json({ error: "Invalid parentSessionId format" }); + return; + } + const channelId = parts[1]; + const forkedSessionId = buildChatSessionId(channelId, newThreadId); + + const input: ForkThreadInput = { + userId, + forkedSessionId, + parentSessionId, + forkMessageId, + }; + const result = await forkThread(input); + + res.json({ + threadId: newThreadId, + channelId, + forkedSessionId: result.forkedSessionId, + parentSessionId: result.parentSessionId, + forkWriteSeq: result.forkWriteSeq, + }); + } catch (error) { + console.error("Fork thread error:", error); + if (error instanceof Error && error.message.startsWith("Fork message not found")) { + res.status(404).json({ error: error.message }); + return; + } + res.status(500).json({ error: "Internal server error" }); + } +}); + router.put("/scope-settings", async (req: AuthRequest, res: Response) => { try { const userId = req.userId; diff --git a/apps/node_backend/src/services/chatAsyncTransportService.ts b/apps/node_backend/src/services/chatAsyncTransportService.ts index a955ea1..f0cad1c 100644 --- a/apps/node_backend/src/services/chatAsyncTransportService.ts +++ b/apps/node_backend/src/services/chatAsyncTransportService.ts @@ -309,13 +309,68 @@ export async function syncMessages( } +interface ThreadForkRow { + parent_session_id: string; + fork_write_seq: string; +} + +/** + * Look up the fork record for a session, if one exists. + * Returns null when the session is not a fork. + */ +async function getThreadFork( + userId: string, + sessionId: string, +): Promise<{ parentSessionId: string; forkWriteSeq: bigint } | null> { + const result = await pool.query( + `SELECT parent_session_id, fork_write_seq + FROM chat_thread_forks + WHERE user_id = $1 + AND forked_session_id = $2 + LIMIT 1`, + [userId, sessionId], + ); + if (result.rows.length === 0) return null; + const row = result.rows[0]; + return { + parentSessionId: row.parent_session_id, + forkWriteSeq: BigInt(row.fork_write_seq), + }; +} + +/** + * Collect messages from rows into the budget, oldest-first. + * rows must be supplied newest-first (ORDER BY write_seq DESC). + */ +function collectMessages( + rows: ChatMessageRow[], + budget: { used: number; maxChars: number }, +): Array<{ role: 'user' | 'assistant'; content: string }> { + const collected: Array<{ role: 'user' | 'assistant'; content: string }> = []; + for (const row of rows) { + const content = row.content?.trim() ?? ''; + if (!content) continue; + if (budget.used + content.length > budget.maxChars) break; + budget.used += content.length; + collected.push({ role: row.role as 'user' | 'assistant', content }); + } + return collected.reverse(); +} + export async function listSessionMessagesForModel( userId: string, sessionId: string, options: { limit?: number; maxChars?: number } = {}, ): Promise> { const limit = Math.max(1, Math.min(options.limit ?? 40, 200)); - const result = await pool.query( + const maxChars = Math.max(200, Math.min(options.maxChars ?? 8000, 64000)); + const budget = { used: 0, maxChars }; + + // Check if this session is a fork; if so, prepend parent context first. + const fork = await getThreadFork(userId, sessionId); + + // Fetch own messages (newest-first so we can apply the char budget). + const ownResult = await pool.query( `SELECT seq_id, write_seq, message_id, task_id, channel_id, session_id, thread_id, role, content, task_state, checkpoint_cursor, metadata, created_at, updated_at FROM chat_messages @@ -327,24 +382,78 @@ export async function listSessionMessagesForModel( [userId, sessionId, limit], ); - // result.rows is already newest-first (ORDER BY write_seq DESC). - // Collect messages greedily from newest to oldest so that the most recent - // turns are always included; stop as soon as adding the next message would - // exceed the budget. Reverse at the end to restore chronological order. - const maxChars = Math.max(200, Math.min(options.maxChars ?? 8000, 64000)); - const collected: Array<{ role: 'user' | 'assistant'; content: string }> = []; - let used = 0; - for (const row of result.rows) { - const content = row.content?.trim() ?? ''; - if (!content) continue; - if (used + content.length > maxChars) break; - used += content.length; - collected.push({ - role: row.role as 'user' | 'assistant', - content, - }); + // Collect the forked session's own messages against the shared budget. + const ownMessages = collectMessages(ownResult.rows, budget); + + if (!fork) { + return ownMessages; } - return collected.reverse(); + + // Fetch parent messages up to (and including) the fork point. + const parentResult = await pool.query( + `SELECT seq_id, write_seq, message_id, task_id, channel_id, session_id, thread_id, + role, content, task_state, checkpoint_cursor, metadata, created_at, updated_at + FROM chat_messages + WHERE user_id = $1 + AND session_id = $2 + AND role IN ('user', 'assistant') + AND write_seq <= $3 + ORDER BY write_seq DESC + LIMIT $4`, + [userId, fork.parentSessionId, fork.forkWriteSeq, limit], + ); + + const parentMessages = collectMessages(parentResult.rows, budget); + + // Return parent context first (chronological), then the fork's own messages. + return [...parentMessages, ...ownMessages]; +} + +export interface ForkThreadInput { + userId: string; + forkedSessionId: string; + parentSessionId: string; + forkMessageId: string; +} + +export interface ForkThreadResult { + forkedSessionId: string; + parentSessionId: string; + forkMessageId: string; + forkWriteSeq: number; +} + +/** + * Record a thread fork in the database. + * The fork point is identified by the parent message_id; its write_seq is + * looked up at insert time so that context assembly can use a stable cursor. + */ +export async function forkThread(input: ForkThreadInput): Promise { + const { userId, forkedSessionId, parentSessionId, forkMessageId } = input; + + // Look up the write_seq for the fork point message. + const msgResult = await pool.query<{ write_seq: string }>( + `SELECT write_seq FROM chat_messages + WHERE user_id = $1 + AND session_id = $2 + AND message_id = $3 + LIMIT 1`, + [userId, parentSessionId, forkMessageId], + ); + if (msgResult.rows.length === 0) { + throw new Error(`Fork message not found: ${forkMessageId}`); + } + const forkWriteSeq = Number(msgResult.rows[0].write_seq); + + await pool.query( + `INSERT INTO chat_thread_forks + (user_id, forked_session_id, parent_session_id, fork_message_id, fork_write_seq) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (user_id, forked_session_id) DO NOTHING`, + [userId, forkedSessionId, parentSessionId, forkMessageId, forkWriteSeq], + ); + + return { forkedSessionId, parentSessionId, forkMessageId, forkWriteSeq }; } export async function listUserScopes(userId: string): Promise {