Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion firebase-ai/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

- [feature] Added support for configuring thinking levels with Gemini 3 series
models and onwards. (#7599)
- [changed] Added `equals()` function to `GenerativeBackend`.
- [feature] Added support for [API Key
restrictions](https://docs.cloud.google.com/docs/authentication/api-keys#adding-application-restrictions) (#7679)
- [changed] Added `equals()` function to `GenerativeBackend`. (#7597)

# 17.7.0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package com.google.firebase.ai.common

import android.content.pm.PackageManager
import android.content.pm.Signature
import android.os.Build
import android.util.Log
import com.google.firebase.Firebase
import com.google.firebase.FirebaseApp
Expand Down Expand Up @@ -65,6 +68,8 @@ import io.ktor.http.contentType
import io.ktor.http.withCharset
import io.ktor.serialization.kotlinx.json.json
import io.ktor.utils.io.charsets.Charset
import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
import kotlin.math.max
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
Expand Down Expand Up @@ -138,6 +143,8 @@ internal constructor(
)

private val model = fullModelName(model)
private val appPackageName by lazy { firebaseApp.applicationContext.packageName }
private val appSigningCertFingerprint by lazy { getSigningCertFingerprint() }

private val client =
HttpClient(httpEngine) {
Expand Down Expand Up @@ -268,6 +275,8 @@ internal constructor(
contentType(ContentType.Application.Json)
header("x-goog-api-key", key)
header("x-goog-api-client", apiClient)
header("X-Android-Package", appPackageName)
header("X-Android-Cert", appSigningCertFingerprint ?: "")
if (firebaseApp.isDataCollectionDefaultEnabled) {
header("X-Firebase-AppId", googleAppId)
header("X-Firebase-AppVersion", appVersion)
Expand Down Expand Up @@ -345,6 +354,64 @@ internal constructor(
}
}

@OptIn(ExperimentalStdlibApi::class)
private fun getSigningCertFingerprint(): String? {
val signature = getCurrentSignature() ?: return null
try {
val messageDigest = MessageDigest.getInstance("SHA-1")
val digest = messageDigest.digest(signature.toByteArray())
return digest.toHexString(HexFormat.UpperCase)
} catch (e: NoSuchAlgorithmException) {
Log.w(TAG, "No support for SHA-1 algorithm found.", e)
return null
}
}

@Suppress("DEPRECATION")
private fun getCurrentSignature(): Signature? {
val packageName = firebaseApp.applicationContext.packageName
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.P) {
val packageInfo =
try {
firebaseApp.applicationContext.packageManager.getPackageInfo(
packageName,
PackageManager.GET_SIGNATURES
)
} catch (e: PackageManager.NameNotFoundException) {
Log.d(TAG, "PackageManager couldn't find the package \"$packageName\"")
return null
}
val signatures = packageInfo?.signatures ?: return null
if (signatures.size > 1) {
Log.d(
TAG,
"Multiple certificates found. On Android < P, certificate order is non-deterministic; an rotated/old cert may be used."
)
}
return signatures.firstOrNull()
}
val packageInfo =
try {
firebaseApp.applicationContext.packageManager.getPackageInfo(
packageName,
PackageManager.GET_SIGNING_CERTIFICATES
)
} catch (e: PackageManager.NameNotFoundException) {
Log.d(TAG, "PackageManager couldn't find the package \"$packageName\"")
return null
}
val signingInfo = packageInfo?.signingInfo ?: return null
if (signingInfo.hasMultipleSigners()) {
Log.d(TAG, "App has been signed with multiple certificates. Defaulting to the first one")
return signingInfo.apkContentsSigners.first()
} else {
// The `signingCertificateHistory` contains a sorted list of certificates used to sign this
// artifact, with the original one first, and once it's rotated, the current one is added at
// the end of the list. See the method's refdocs for more info.
return signingInfo.signingCertificateHistory.lastOrNull()
}
}

companion object {
private val TAG = APIController::class.java.simpleName

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.ai

import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.firebase.ai.type.FinishReason
import com.google.firebase.ai.type.InvalidAPIKeyException
import com.google.firebase.ai.type.PublicPreviewAPI
Expand All @@ -36,7 +37,9 @@ import io.ktor.http.HttpStatusCode
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.withTimeout
import org.junit.Test
import org.junit.runner.RunWith

@RunWith(AndroidJUnit4::class)
internal class DevAPIUnarySnapshotTests {
private val testTimeout = 5.seconds

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@

package com.google.firebase.ai

import android.content.Context
import android.content.pm.PackageManager
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.firebase.FirebaseApp
import com.google.firebase.ai.common.APIController
import com.google.firebase.ai.common.JSON
import com.google.firebase.ai.common.util.doBlocking
import com.google.firebase.ai.type.Candidate
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.CountTokensResponse
import com.google.firebase.ai.type.GenerateContentResponse
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.HarmBlockMethod
Expand All @@ -41,6 +46,7 @@ import io.kotest.assertions.json.shouldContainJsonKey
import io.kotest.assertions.json.shouldContainJsonKeyValue
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.types.shouldBeInstanceOf
import io.ktor.client.engine.mock.MockEngine
Expand All @@ -50,12 +56,15 @@ import io.ktor.http.HttpStatusCode
import io.ktor.http.content.TextContent
import io.ktor.http.headersOf
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.encodeToString
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito

@RunWith(AndroidJUnit4::class)
internal class GenerativeModelTesting {
private val TEST_CLIENT_ID = "test"
private val TEST_APP_ID = "1:android:12345"
Expand All @@ -65,7 +74,9 @@ internal class GenerativeModelTesting {

@Before
fun setup() {
val context = ApplicationProvider.getApplicationContext<Context>()
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
Mockito.`when`(mockFirebaseApp.applicationContext).thenReturn(context)
}

@Test
Expand Down Expand Up @@ -112,6 +123,104 @@ internal class GenerativeModelTesting {
}
}

@Test
fun `security headers are included in request`() = doBlocking {
val mockEngine = MockEngine {
respond(
generateContentResponseAsJsonString("text response"),
HttpStatusCode.OK,
headersOf(HttpHeaders.ContentType, "application/json")
)
}
val generativeModel = generativeModelWithMockEngine(mockEngine)

withTimeout(5.seconds) { generativeModel.generateContent("my test prompt") }

val headers = mockEngine.requestHistory.first().headers
headers["X-Android-Package"] shouldBe "com.google.firebase.ai.test"
// X-Android-Cert will be empty because Robolectric doesn't provide signatures by default
headers["X-Android-Cert"] shouldBe ""
}

@Test
fun `security headers are included in streaming request`() = doBlocking {
val mockEngine = MockEngine {
respond(
generateContentResponseAsJsonString("text response"),
HttpStatusCode.OK,
headersOf(HttpHeaders.ContentType, "application/json")
)
}
val generativeModel = generativeModelWithMockEngine(mockEngine)

withTimeout(5.seconds) { generativeModel.generateContentStream("my test prompt").collect() }

val headers = mockEngine.requestHistory.first().headers
headers["X-Android-Package"] shouldBe "com.google.firebase.ai.test"
// X-Android-Cert will be empty because Robolectric doesn't provide signatures by default
headers["X-Android-Cert"] shouldBe ""
}

@Test
fun `security headers are included in countTokens request`() = doBlocking {
val mockEngine = MockEngine {
respond(
JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)),
HttpStatusCode.OK,
headersOf(HttpHeaders.ContentType, "application/json")
)
}
val generativeModel = generativeModelWithMockEngine(mockEngine)

withTimeout(5.seconds) { generativeModel.countTokens("my test prompt") }

val headers = mockEngine.requestHistory.first().headers
headers["X-Android-Package"] shouldBe "com.google.firebase.ai.test"
// X-Android-Cert will be empty because Robolectric doesn't provide signatures by default
headers["X-Android-Cert"] shouldBe ""
}

@Test
fun `X-Android-Cert is empty when signatures are missing`() = doBlocking {
val mockEngine = MockEngine {
respond(
generateContentResponseAsJsonString("text response"),
HttpStatusCode.OK,
headersOf(HttpHeaders.ContentType, "application/json")
)
}

val mockPackageManager = Mockito.mock(PackageManager::class.java)
val mockContext = Mockito.mock(Context::class.java)
Mockito.`when`(mockContext.packageName).thenReturn("com.test.app")
Mockito.`when`(mockContext.packageManager).thenReturn(mockPackageManager)

val mockApp = Mockito.mock(FirebaseApp::class.java)
Mockito.`when`(mockApp.applicationContext).thenReturn(mockContext)

val apiController =
APIController(
"super_cool_test_key",
"gemini-2.5-flash",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

val generativeModel = GenerativeModel("gemini-2.5-flash", controller = apiController)

withTimeout(5.seconds) { generativeModel.generateContent("my test prompt") }

val headers = mockEngine.requestHistory.first().headers
headers["X-Android-Package"] shouldBe "com.test.app"
// X-Android-Cert will be empty because Robolectric doesn't provide signatures by default
headers["X-Android-Cert"] shouldBe ""
}

@Test
fun `exception thrown when using invalid location`() = doBlocking {
val mockEngine = MockEngine {
Expand Down Expand Up @@ -310,4 +419,21 @@ internal class GenerativeModelTesting {
it.shouldContainJsonKeyValue("$.generation_config.thinking_config.thinking_level", "MEDIUM")
}
}

private fun generativeModelWithMockEngine(mockEngine: MockEngine): GenerativeModel {
val apiController =
APIController(
"super_cool_test_key",
"gemini-2.5-flash",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

return GenerativeModel("gemini-2.5-flash", controller = apiController)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

package com.google.firebase.ai.common

import android.content.Context
import androidx.test.core.app.ApplicationProvider
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.firebase.FirebaseApp
import com.google.firebase.ai.BuildConfig
import com.google.firebase.ai.common.util.commonTest
Expand Down Expand Up @@ -56,15 +59,16 @@ import kotlinx.serialization.json.JsonObject
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.mockito.Mockito
import org.robolectric.ParameterizedRobolectricTestRunner

private val TEST_CLIENT_ID = "genai-android/test"

private val TEST_APP_ID = "1:android:12345"

private val TEST_VERSION = 1

@RunWith(AndroidJUnit4::class)
internal class APIControllerTests {
private val testTimeout = 5.seconds

Expand Down Expand Up @@ -96,13 +100,16 @@ internal class APIControllerTests {
}

@OptIn(ExperimentalSerializationApi::class)
@RunWith(AndroidJUnit4::class)
internal class RequestFormatTests {

private val mockFirebaseApp = Mockito.mock<FirebaseApp>()

@Before
fun setup() {
val context = ApplicationProvider.getApplicationContext<Context>()
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
Mockito.`when`(mockFirebaseApp.applicationContext).thenReturn(context)
}

@Test
Expand Down Expand Up @@ -454,13 +461,15 @@ internal class RequestFormatTests {
}
}

@RunWith(Parameterized::class)
@RunWith(ParameterizedRobolectricTestRunner::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {
private val mockFirebaseApp = Mockito.mock<FirebaseApp>()

@Before
fun setup() {
val context = ApplicationProvider.getApplicationContext<Context>()
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
Mockito.`when`(mockFirebaseApp.applicationContext).thenReturn(context)
}

@Test
Expand Down Expand Up @@ -495,7 +504,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua

companion object {
@JvmStatic
@Parameterized.Parameters
@ParameterizedRobolectricTestRunner.Parameters
fun data() =
listOf(
arrayOf("gemini-pro", "models/gemini-pro"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package com.google.firebase.ai.common.util

import android.content.Context
import androidx.test.core.app.ApplicationProvider
import com.google.firebase.FirebaseApp
import com.google.firebase.ai.common.APIController
import com.google.firebase.ai.common.JSON
Expand Down Expand Up @@ -95,7 +97,9 @@ internal fun commonTest(
block: CommonTest,
) = doBlocking {
val mockFirebaseApp = Mockito.mock<FirebaseApp>()
val context = ApplicationProvider.getApplicationContext<Context>()
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
Mockito.`when`(mockFirebaseApp.applicationContext).thenReturn(context)

val channel = ByteChannel(autoFlush = true)
val apiController =
Expand Down
Loading
Loading