Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,44 +1,40 @@
package com.google.firebase.quickstart.ai.feature.hybrid

import android.graphics.Bitmap
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.firebase.Firebase
import com.google.firebase.ai.DownloadStatus
import com.google.firebase.ai.InferenceMode
import com.google.firebase.ai.InferenceSource
import com.google.firebase.ai.OnDeviceConfig
import com.google.firebase.ai.OnDeviceModelStatus
import com.google.firebase.ai.ai
import com.google.firebase.ai.ondevice.DownloadStatus
import com.google.firebase.ai.ondevice.FirebaseAIOnDevice
import com.google.firebase.ai.ondevice.OnDeviceModelStatus
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.content
import com.google.firebase.quickstart.ai.ui.HybridInferenceUiState
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import java.util.UUID

@Serializable
object HybridInferenceRoute

@OptIn(PublicPreviewAPI::class)
class HybridInferenceViewModel : ViewModel() {
private val _uiState = MutableStateFlow(
HybridInferenceUiState(
expenses = listOf(
Expense("Lunch", 15.50, "Example data"),
Expense("Coffee", 4.75, "Example data")
val uiState: StateFlow<HybridInferenceUiState>
field = MutableStateFlow(
HybridInferenceUiState(
expenses = listOf(
Expense("Lunch", 15.50, "Example data"),
Expense("Coffee", 4.75, "Example data")
)
)
)
Comment thread
thatfiredev marked this conversation as resolved.
)
val uiState: StateFlow<HybridInferenceUiState> = _uiState.asStateFlow()

private val model = Firebase.ai(backend = GenerativeBackend.googleAI()).generativeModel(
modelName = "gemini-3.1-flash-lite",
Expand All @@ -52,27 +48,27 @@ class HybridInferenceViewModel : ViewModel() {
private fun checkAndDownloadModel() {
viewModelScope.launch {
try {
val status = FirebaseAIOnDevice.checkStatus()
val status = model.onDeviceExtension?.checkStatus()
updateStatus(status)

if (status == OnDeviceModelStatus.DOWNLOADABLE) {
FirebaseAIOnDevice.download().collect { downloadStatus ->
model.onDeviceExtension?.download()?.collect { downloadStatus ->
when (downloadStatus) {
is DownloadStatus.DownloadStarted -> {
_uiState.update { it.copy(modelStatus = "Downloading model...") }
uiState.update { it.copy(modelStatus = "Downloading model...") }
}

is DownloadStatus.DownloadInProgress -> {
val progress = downloadStatus.totalBytesDownloaded
_uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") }
uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") }
}

is DownloadStatus.DownloadCompleted -> {
_uiState.update { it.copy(modelStatus = "Model ready") }
uiState.update { it.copy(modelStatus = "Model ready") }
}

is DownloadStatus.DownloadFailed -> {
_uiState.update {
uiState.update {
it.copy(
modelStatus = "Download failed", errorMessage = "Model download failed"
)
Expand All @@ -82,25 +78,24 @@ class HybridInferenceViewModel : ViewModel() {
}
}
} catch (e: Exception) {
_uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) }
uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) }
}
}
}

private fun updateStatus(status: OnDeviceModelStatus) {
private fun updateStatus(status: OnDeviceModelStatus?) {
val statusText = when (status) {
OnDeviceModelStatus.AVAILABLE -> "Model available"
OnDeviceModelStatus.DOWNLOADABLE -> "Model downloadable"
OnDeviceModelStatus.DOWNLOADING -> "Model downloading..."
OnDeviceModelStatus.UNAVAILABLE -> "On-device model unavailable"
else -> "Unknown"
else -> "On-device model unavailable"
}
_uiState.update { it.copy(modelStatus = statusText) }
uiState.update { it.copy(modelStatus = statusText) }
}

fun scanReceipt(bitmap: Bitmap) {
viewModelScope.launch {
_uiState.update { it.copy(isScanning = true, errorMessage = null) }
uiState.update { it.copy(isScanning = true, errorMessage = null) }
try {
val prompt = content {
image(bitmap)
Expand All @@ -124,16 +119,15 @@ class HybridInferenceViewModel : ViewModel() {
} else {
"Cloud"
}
Log.d("HybridVM", "$inferenceMode response: $text")
if (text != null) {
parseAndAddExpense(text, inferenceMode)
} else {
_uiState.update { it.copy(errorMessage = "Could not extract data") }
uiState.update { it.copy(errorMessage = "Could not extract data") }
}
} catch (e: Exception) {
_uiState.update { it.copy(errorMessage = "Error: ${e.message}") }
uiState.update { it.copy(errorMessage = "Error: ${e.message}") }
} finally {
_uiState.update { it.copy(isScanning = false) }
uiState.update { it.copy(isScanning = false) }
}
}
}
Expand All @@ -145,9 +139,9 @@ class HybridInferenceViewModel : ViewModel() {
.replace("```", "")
try {
val newExpense = Json.decodeFromString<Expense>(json).copy(inferenceMode = inferenceMode)
_uiState.update { it.copy(expenses = it.expenses + newExpense) }
uiState.update { it.copy(expenses = it.expenses + newExpense) }
} catch (e: Exception) {
_uiState.update { it.copy(errorMessage = e.localizedMessage) }
uiState.update { it.copy(errorMessage = e.localizedMessage) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class AudioSummarizationViewModel : ChatViewModel() {
)
}
))
_messages.value = chat.history.map { UiChatMessage(it) }
_uiState.value = ChatUiState.Success
updateMessages(chat.history.map { UiChatMessage(it) })
updateUiState(ChatUiState.Success)
}

override suspend fun performSendMessage(prompt: Content, currentMessages: List<UiChatMessage>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,27 @@ import com.google.firebase.quickstart.ai.ui.ChatUiState
import com.google.firebase.quickstart.ai.ui.UiChatMessage
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch

@OptIn(PublicPreviewAPI::class)
abstract class ChatViewModel : ViewModel() {

protected val _uiState = MutableStateFlow<ChatUiState>(ChatUiState.Success)
val uiState: StateFlow<ChatUiState> = _uiState.asStateFlow()
val uiState: StateFlow<ChatUiState>
field = MutableStateFlow<ChatUiState>(ChatUiState.Success)

protected val _messages = MutableStateFlow<List<UiChatMessage>>(emptyList())
val messages: StateFlow<List<UiChatMessage>> = _messages.asStateFlow()
val messages: StateFlow<List<UiChatMessage>>
field = MutableStateFlow<List<UiChatMessage>>(emptyList())

protected val _attachments = MutableStateFlow<List<Attachment>>(emptyList())
val attachments: StateFlow<List<Attachment>> = _attachments.asStateFlow()
val attachments: StateFlow<List<Attachment>>
field = MutableStateFlow<List<Attachment>>(emptyList())
Comment thread
thatfiredev marked this conversation as resolved.

protected fun updateUiState(state: ChatUiState) {
uiState.value = state
}

protected fun updateMessages(list: List<UiChatMessage>) {
messages.value = list
}

abstract val initialPrompt: String

Expand All @@ -40,14 +47,14 @@ abstract class ChatViewModel : ViewModel() {
.text(userMessage)
.build()

_messages.value = _messages.value + UiChatMessage(prompt)
messages.value = messages.value + UiChatMessage(prompt)

viewModelScope.launch {
_uiState.value = ChatUiState.Loading
uiState.value = ChatUiState.Loading
try {
performSendMessage(prompt, _messages.value)
performSendMessage(prompt, messages.value)
} catch (e: Exception) {
_uiState.value = ChatUiState.Error(e.localizedMessage ?: "Unknown error")
uiState.value = ChatUiState.Error(e.localizedMessage ?: "Unknown error")
} finally {
contentBuilder = Content.Builder() // reset the builder
}
Expand Down Expand Up @@ -76,13 +83,13 @@ abstract class ChatViewModel : ViewModel() {
&& candidate.groundingMetadata?.groundingChunks?.isNotEmpty() == true
&& candidate.groundingMetadata?.searchEntryPoint == null
) {
_uiState.value = ChatUiState.Error(
uiState.value = ChatUiState.Error(
"Could not display the response because it was missing required attribution components."
)
} else {
_messages.value = currentMessages + UiChatMessage(candidate.content, candidate.groundingMetadata)
_attachments.value = emptyList()
_uiState.value = ChatUiState.Success
messages.value = currentMessages + UiChatMessage(candidate.content, candidate.groundingMetadata)
attachments.value = emptyList()
uiState.value = ChatUiState.Success
}
}

Expand All @@ -98,7 +105,7 @@ abstract class ChatViewModel : ViewModel() {
contentBuilder.inlineData(fileInBytes, mimeType ?: "text/plain")
}

_attachments.value = _attachments.value + Attachment(fileName ?: "Unnamed file")
attachments.value = attachments.value + Attachment(fileName ?: "Unnamed file")
}

protected fun decodeBitmapFromImage(input: ByteArray) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class ServerPromptTemplateViewModel : ViewModel() {
val initialPrompt = "Jane Doe"
val allowEmptyPrompt = false

private val _uiState = MutableStateFlow<ServerPromptUiState>(ServerPromptUiState.Success())
val uiState: StateFlow<ServerPromptUiState> = _uiState.asStateFlow()
val uiState: StateFlow<ServerPromptUiState>
field = MutableStateFlow<ServerPromptUiState>(ServerPromptUiState.Success())
Comment thread
thatfiredev marked this conversation as resolved.

private var templateGenerativeModel: TemplateGenerativeModel

Expand All @@ -35,13 +35,13 @@ class ServerPromptTemplateViewModel : ViewModel() {

fun generate(inputText: String) {
viewModelScope.launch {
_uiState.value = ServerPromptUiState.Loading
uiState.value = ServerPromptUiState.Loading
try {
val response = templateGenerativeModel
.generateContent("input-system-instructions", mapOf("customerName" to inputText))
_uiState.value = ServerPromptUiState.Success(response.text)
uiState.value = ServerPromptUiState.Success(response.text)
} catch (e: Exception) {
_uiState.value = ServerPromptUiState.Error(
uiState.value = ServerPromptUiState.Error(
if (e.localizedMessage?.contains("not found") == true) {
"""
Template was not found, please verify that your project contains a template
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import kotlinx.coroutines.launch
object SvgRoute

class SvgViewModel : ViewModel() {
private val _uiState = MutableStateFlow<SvgUiState>(SvgUiState.Success())
val uiState: StateFlow<SvgUiState> = _uiState.asStateFlow()
val uiState: StateFlow<SvgUiState>
field = MutableStateFlow<SvgUiState>(SvgUiState.Success())
Comment thread
thatfiredev marked this conversation as resolved.

private val generativeModel: GenerativeModel

Expand Down Expand Up @@ -53,8 +53,8 @@ class SvgViewModel : ViewModel() {
}

fun generateSVG(prompt: String) {
val currentSvgs = (_uiState.value as? SvgUiState.Success)?.svgs ?: emptyList()
_uiState.value = SvgUiState.Loading
val currentSvgs = (uiState.value as? SvgUiState.Success)?.svgs ?: emptyList()
uiState.value = SvgUiState.Loading
viewModelScope.launch(Dispatchers.IO) {
try {
val response = generativeModel.generateContent(prompt)
Expand All @@ -64,12 +64,12 @@ class SvgViewModel : ViewModel() {
?.removeSuffix("```")
?.trimIndent()
if (newSvg != null) {
_uiState.value = SvgUiState.Success(listOf(newSvg) + currentSvgs)
uiState.value = SvgUiState.Success(listOf(newSvg) + currentSvgs)
} else {
_uiState.value = SvgUiState.Success(currentSvgs)
uiState.value = SvgUiState.Success(currentSvgs)
}
} catch (e: Exception) {
_uiState.value = SvgUiState.Error(e.localizedMessage ?: "Unknown error")
uiState.value = SvgUiState.Error(e.localizedMessage ?: "Unknown error")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class TravelTipsViewModel : ChatViewModel() {
)
)

_messages.value = chat.history.map { UiChatMessage(it) }
_uiState.value = ChatUiState.Success
updateMessages(chat.history.map { UiChatMessage(it) })
updateUiState(ChatUiState.Success)
}

override suspend fun performSendMessage(prompt: Content, currentMessages: List<UiChatMessage>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class VideoSummarizationViewModel : ChatViewModel() {
}
)

_messages.value = chatHistory.map { UiChatMessage(it) }
_uiState.value = ChatUiState.Success
updateMessages(chatHistory.map { UiChatMessage(it) })
updateUiState(ChatUiState.Success)

val generativeModel = Firebase.ai(
backend = GenerativeBackend.googleAI()
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ firebasePerf = "2.0.2"
gradleVersions = "0.54.0"
junit = "4.13.2"
junitVersion = "1.3.0"
kotlin = "2.3.21"
kotlin = "2.4.0"
kotlinxSerializationCore = "1.11.0"
lifecycle = "2.10.0"
material = "1.14.0"
Expand Down
Loading