Skip to content

Commit 0fb47a7

Browse files
committed
feat: refactor Local AI model management and enhance chat platform display
1 parent dc2cc40 commit 0fb47a7

17 files changed

Lines changed: 1161 additions & 495 deletions

File tree

app/src/main/AndroidManifest.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
<uses-permission android:name="android.permission.INTERNET" />
77
<uses-permission android:name="com.google.android.gms.permission.AD_ID" />
8+
<!-- Required for background model downloads -->
9+
<uses-permission android:name="android.permission.FOREGROUND_SERVICE" />
10+
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC" />
11+
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
812

913
<queries>
1014
<package android:name="com.google.android.aicore" />

app/src/main/kotlin/com/matrix/multigpt/data/database/dao/MessageDao.kt

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,28 @@ interface MessageDao {
2121

2222
@Delete
2323
suspend fun deleteMessages(vararg message: Message)
24+
25+
/**
26+
* Get distinct platforms actually used in a chat (from AI responses).
27+
* Excludes null platforms (user messages).
28+
*/
29+
@Query("SELECT DISTINCT platform_type FROM messages WHERE chat_id=:chatId AND platform_type IS NOT NULL")
30+
suspend fun getUsedPlatforms(chatId: Int): List<String>
31+
32+
/**
33+
* Get all distinct platforms used across all chats.
34+
* Returns map of chatId to used platforms.
35+
*/
36+
@Query("SELECT chat_id, GROUP_CONCAT(DISTINCT platform_type) as platforms FROM messages WHERE platform_type IS NOT NULL GROUP BY chat_id")
37+
suspend fun getAllChatsUsedPlatforms(): List<ChatUsedPlatforms>
2438
}
39+
40+
/**
41+
* Data class for holding chat id and its used platforms.
42+
*/
43+
data class ChatUsedPlatforms(
44+
@androidx.room.ColumnInfo(name = "chat_id")
45+
val chatId: Int,
46+
@androidx.room.ColumnInfo(name = "platforms")
47+
val platforms: String?
48+
)

app/src/main/kotlin/com/matrix/multigpt/data/repository/ChatRepository.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package com.matrix.multigpt.data.repository
33
import com.matrix.multigpt.data.database.entity.ChatRoom
44
import com.matrix.multigpt.data.database.entity.Message
55
import com.matrix.multigpt.data.dto.ApiState
6+
import com.matrix.multigpt.data.model.ApiType
67
import kotlinx.coroutines.flow.Flow
78

89
interface ChatRepository {
@@ -19,4 +20,16 @@ interface ChatRepository {
1920
suspend fun updateChatTitle(chatRoom: ChatRoom, title: String)
2021
suspend fun saveChat(chatRoom: ChatRoom, messages: List<Message>): ChatRoom
2122
suspend fun deleteChats(chatRooms: List<ChatRoom>)
23+
24+
/**
25+
* Get the actually used platforms for a specific chat.
26+
* Returns platforms from messages that have actual AI responses.
27+
*/
28+
suspend fun getUsedPlatformsForChat(chatId: Int): List<ApiType>
29+
30+
/**
31+
* Get all actually used platforms for all chats.
32+
* Returns a map of chatId to list of used platforms.
33+
*/
34+
suspend fun getAllChatsUsedPlatforms(): Map<Int, List<ApiType>>
2235
}

app/src/main/kotlin/com/matrix/multigpt/data/repository/ChatRepositoryImpl.kt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,31 @@ class ChatRepositoryImpl @Inject constructor(
266266
chatRoomDao.deleteChatRooms(*chatRooms.toTypedArray())
267267
}
268268

269+
override suspend fun getUsedPlatformsForChat(chatId: Int): List<ApiType> {
270+
val platformStrings = messageDao.getUsedPlatforms(chatId)
271+
return platformStrings.mapNotNull { platformStr ->
272+
try {
273+
ApiType.valueOf(platformStr)
274+
} catch (e: Exception) {
275+
null
276+
}
277+
}
278+
}
279+
280+
override suspend fun getAllChatsUsedPlatforms(): Map<Int, List<ApiType>> {
281+
val chatPlatforms = messageDao.getAllChatsUsedPlatforms()
282+
return chatPlatforms.associate { chatUsedPlatforms ->
283+
val platforms = chatUsedPlatforms.platforms?.split(",")?.mapNotNull { platformStr ->
284+
try {
285+
ApiType.valueOf(platformStr.trim())
286+
} catch (e: Exception) {
287+
null
288+
}
289+
} ?: emptyList()
290+
chatUsedPlatforms.chatId to platforms
291+
}
292+
}
293+
269294
private fun messageToOpenAICompatibleMessage(apiType: ApiType, messages: List<Message>): List<ChatMessage> {
270295
val result = mutableListOf<ChatMessage>()
271296

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package com.matrix.multigpt.data.source
2+
3+
import kotlinx.coroutines.flow.MutableStateFlow
4+
import kotlinx.coroutines.flow.StateFlow
5+
import kotlinx.coroutines.flow.asStateFlow
6+
import javax.inject.Inject
7+
import javax.inject.Singleton
8+
9+
/**
10+
* Application-scoped singleton to track download state.
11+
* Persists across navigation because it's a Singleton.
12+
*/
13+
@Singleton
14+
class LocalModelDownloadState @Inject constructor() {
15+
16+
private val _downloadingModels = MutableStateFlow<Set<String>>(emptySet())
17+
val downloadingModels: StateFlow<Set<String>> = _downloadingModels.asStateFlow()
18+
19+
private val _downloadProgressMap = MutableStateFlow<Map<String, Float>>(emptyMap())
20+
val downloadProgressMap: StateFlow<Map<String, Float>> = _downloadProgressMap.asStateFlow()
21+
22+
private val _downloadedBytesMap = MutableStateFlow<Map<String, Long>>(emptyMap())
23+
val downloadedBytesMap: StateFlow<Map<String, Long>> = _downloadedBytesMap.asStateFlow()
24+
25+
private val _totalBytesMap = MutableStateFlow<Map<String, Long>>(emptyMap())
26+
val totalBytesMap: StateFlow<Map<String, Long>> = _totalBytesMap.asStateFlow()
27+
28+
fun startDownload(modelId: String, totalSize: Long) {
29+
_downloadingModels.value = _downloadingModels.value + modelId
30+
_downloadProgressMap.value = _downloadProgressMap.value + (modelId to 0f)
31+
_totalBytesMap.value = _totalBytesMap.value + (modelId to totalSize)
32+
_downloadedBytesMap.value = _downloadedBytesMap.value + (modelId to 0L)
33+
}
34+
35+
fun updateProgress(modelId: String, progress: Float, downloadedBytes: Long, totalBytes: Long) {
36+
_downloadProgressMap.value = _downloadProgressMap.value + (modelId to progress)
37+
_downloadedBytesMap.value = _downloadedBytesMap.value + (modelId to downloadedBytes)
38+
if (totalBytes > 0) {
39+
_totalBytesMap.value = _totalBytesMap.value + (modelId to totalBytes)
40+
}
41+
}
42+
43+
fun completeDownload(modelId: String) {
44+
_downloadingModels.value = _downloadingModels.value - modelId
45+
_downloadProgressMap.value = _downloadProgressMap.value - modelId
46+
_downloadedBytesMap.value = _downloadedBytesMap.value - modelId
47+
_totalBytesMap.value = _totalBytesMap.value - modelId
48+
}
49+
50+
fun isDownloading(modelId: String): Boolean {
51+
return _downloadingModels.value.contains(modelId)
52+
}
53+
54+
fun getProgress(modelId: String): Float {
55+
return _downloadProgressMap.value[modelId] ?: 0f
56+
}
57+
}

app/src/main/kotlin/com/matrix/multigpt/presentation/ui/chat/ChatScreen.kt

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,12 @@ fun ChatScreen(
325325
.animateItem()
326326
) {
327327
Spacer(modifier = Modifier.width(8.dp))
328-
chatViewModel.enabledPlatformsInChat.sorted().forEach { apiType ->
329-
val message = when (apiType) {
328+
// Only show the active provider (the one currently selected in currentModels)
329+
val activeApiType = currentModels.keys.firstOrNull()
330+
?: chatViewModel.enabledPlatformsInChat.firstOrNull()
331+
332+
if (activeApiType != null) {
333+
val message = when (activeApiType) {
330334
ApiType.OPENAI -> openAIMessage
331335
ApiType.ANTHROPIC -> anthropicMessage
332336
ApiType.GOOGLE -> googleMessage
@@ -336,7 +340,7 @@ fun ChatScreen(
336340
ApiType.LOCAL -> localMessage
337341
}
338342

339-
val loadingState = when (apiType) {
343+
val loadingState = when (activeApiType) {
340344
ApiType.OPENAI -> openaiLoadingState
341345
ApiType.ANTHROPIC -> anthropicLoadingState
342346
ApiType.GOOGLE -> googleLoadingState
@@ -350,7 +354,7 @@ fun ChatScreen(
350354
if (loadingState == ChatViewModel.LoadingState.Loading && message.content.isEmpty()) {
351355
TypingIndicator(
352356
modifier = Modifier.padding(horizontal = 8.dp, vertical = 12.dp),
353-
apiType = apiType
357+
apiType = activeApiType
354358
)
355359
} else {
356360
OpponentChatBubble(
@@ -359,9 +363,9 @@ fun ChatScreen(
359363
.widthIn(max = maximumChatBubbleWidth),
360364
canRetry = canUseChat,
361365
isLoading = loadingState == ChatViewModel.LoadingState.Loading,
362-
apiType = apiType,
366+
apiType = activeApiType,
363367
text = message.content,
364-
modelName = currentModels[apiType],
368+
modelName = currentModels[activeApiType],
365369
onCopyClick = {
366370
haptic.performHapticFeedback(HapticFeedbackType.LongPress)
367371
clipboardManager.setText(AnnotatedString(message.content.trim()))

app/src/main/kotlin/com/matrix/multigpt/presentation/ui/home/HomeScreen.kt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,19 @@ fun HomeScreen(
206206
) {
207207
item { ChatsTitle(scrollBehavior) }
208208
itemsIndexed(chatListState.chats, key = { _, it -> it.id }) { idx, chatRoom ->
209-
val usingPlatform = chatRoom.enabledPlatform.joinToString(", ") { platformTitles[it] ?: "" }
209+
// Use actually used platforms from messages, fall back to enabled platforms if empty
210+
val actualUsedPlatforms = chatListState.usedPlatformsMap[chatRoom.id] ?: emptyList()
211+
val platformsToShow = if (actualUsedPlatforms.isNotEmpty()) actualUsedPlatforms else chatRoom.enabledPlatform
212+
val usingPlatform = when {
213+
platformsToShow.size <= 2 -> {
214+
platformsToShow.joinToString(", ") { platformTitles[it] ?: it.name }
215+
}
216+
else -> {
217+
val first = platformTitles[platformsToShow.first()] ?: platformsToShow.first().name
218+
val remaining = platformsToShow.size - 1
219+
"$first and $remaining more"
220+
}
221+
}
210222
ListItem(
211223
modifier = Modifier
212224
.fillMaxWidth()

app/src/main/kotlin/com/matrix/multigpt/presentation/ui/home/HomeViewModel.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ class HomeViewModel @Inject constructor(
3131
data class ChatListState(
3232
val chats: List<ChatRoom> = listOf(),
3333
val isSelectionMode: Boolean = false,
34-
val selected: List<Boolean> = listOf()
34+
val selected: List<Boolean> = listOf(),
35+
val usedPlatformsMap: Map<Int, List<ApiType>> = emptyMap() // Actually used platforms per chat
3536
)
3637

3738
private val _chatListState = MutableStateFlow(ChatListState())
@@ -123,16 +124,21 @@ class HomeViewModel @Inject constructor(
123124
fun fetchChats() {
124125
viewModelScope.launch {
125126
val chats = chatRepository.fetchChatList()
127+
128+
// Fetch actually used platforms for all chats
129+
val usedPlatformsMap = chatRepository.getAllChatsUsedPlatforms()
126130

127131
_chatListState.update {
128132
it.copy(
129133
chats = chats,
130134
selected = List(chats.size) { false },
131-
isSelectionMode = false
135+
isSelectionMode = false,
136+
usedPlatformsMap = usedPlatformsMap
132137
)
133138
}
134139

135140
Log.d("chats", "${_chatListState.value.chats}")
141+
Log.d("HomeViewModel", "Used platforms per chat: $usedPlatformsMap")
136142
}
137143
}
138144

0 commit comments

Comments
 (0)