Add LlmGenerationConfig, error handling, and callback contract tests #19714
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19714
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs
There are 1 currently active SEVs. If your PR is affected, please view them below:
✅ No Failures
As of commit 171fb1b with merge base b4a9e72.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@psiddh has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105886676.
This PR needs a
Pull request overview
Adds a new Android instrumentation test suite to exercise the public LlmGenerationConfig builder API when calling LlmModule.generate(prompt, config, callback), and to validate key error paths and callback “contract” behaviors for the LLM Android wrapper.
Changes:
- Introduces LlmGenerationConfigTest covering config-driven generation (default config, seqLen, echo).
- Adds tests for error handling paths (invalid model path, empty prompt behavior, generate-after-close).
- Adds callback contract assertions (onResult frequency bounds, onStats JSON schema, callback ordering).
Comments suppressed due to low confidence (3)
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmGenerationConfigTest.kt:95
testEchoModeTrue/testEchoModeFalse assert on generated content (e.g., output contains "Hello" / doesn’t start with the prompt). Other LLM instrumentation tests in this suite explicitly avoid asserting content for the TinyStories fixture because it’s not instruction-tuned and content is not stable. These checks are also vulnerable to false failures when the model happens to generate the same prefix as the prompt even with echo disabled. Consider asserting a structural invariant instead (e.g., compare echo=true vs echo=false outputs from a reset context and verify echo=true output is the prompt prefix + echo=false output).
@Test(timeout = MAX_TEST_TIMEOUT_MS)
fun testEchoModeTrue() {
    val config = buildConfig(echo = true)
    val callback = CollectingCallback()
    llmModule.generate(TEST_PROMPT, config, callback)
    val output = callback.results.joinToString("")
    assertTrue(
        "Echo mode should include prompt tokens",
        output.contains("Hello") || output.contains("hello"),
    )
}

@Test(timeout = MAX_TEST_TIMEOUT_MS)
fun testEchoModeFalse() {
    val config = buildConfig(echo = false)
    val callback = CollectingCallback()
    llmModule.generate(TEST_PROMPT, config, callback)
    assertTrue("Should produce output", callback.results.isNotEmpty())
    val output = callback.results.joinToString("")
    assertFalse(
        "With echo=false, output should NOT start with prompt text",
        output.startsWith(TEST_PROMPT),
    )
}
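The structural invariant suggested above can be sketched without touching model content. This is a minimal illustration, not the ExecuTorch API: `echoInvariantViolation` and the `generate` function parameter are hypothetical names, and the sketch assumes generation is deterministic and that each call starts from a reset context, as the comment describes.

```kotlin
/**
 * Hypothetical helper (not part of ExecuTorch): checks that the echo=true
 * output is exactly the prompt followed by the echo=false output.
 * Returns null when the invariant holds, otherwise a description of the mismatch.
 *
 * `generate` stands in for a deterministic generate call made from a reset context.
 */
fun echoInvariantViolation(
    prompt: String,
    generate: (prompt: String, echo: Boolean) -> String,
): String? {
    val withEcho = generate(prompt, true)
    val withoutEcho = generate(prompt, false)
    return when {
        !withEcho.startsWith(prompt) ->
            "echo=true output does not start with the prompt"
        withEcho.removePrefix(prompt) != withoutEcho ->
            "echo=true continuation differs from echo=false output"
        else -> null
    }
}
```

In an instrumentation test, the lambda would wrap the real generate call plus whatever context reset the module offers, and the test would assert that the returned value is null. Because only the relationship between the two outputs is compared, the check does not depend on what TinyStories happens to generate.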
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmGenerationConfigTest.kt:107
The comment that onResult callbacks are “1:1 with tokens” is inaccurate: the JNI layer buffers tokens until it has valid UTF-8 and may coalesce multiple tokens into a single onResult (see extension/android/jni/jni_layer_llama.cpp token_buffer logic). Using results.size as a token count can therefore be misleading; consider rewording the comment and/or validating seqLen via stats (e.g., generated_tokens) instead of callback count.
// Note: results.size counts onResult callbacks, which is 1:1 with tokens for LlmModule
assertTrue("Should produce at least 1 token", callback.results.isNotEmpty())
assertTrue(
    "Token count (${callback.results.size}) should be <= seqLen ($shortSeqLen)",
    callback.results.size <= shortSeqLen,
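To illustrate why callback count and token count can diverge, here is a rough model of UTF-8 buffering. It is a sketch of the general technique only, not the actual logic in jni_layer_llama.cpp; `decodeStrictUtf8` and `emitBuffered` are hypothetical names, and token pieces are modeled as raw byte arrays.

```kotlin
import java.nio.ByteBuffer
import java.nio.charset.CharacterCodingException
import java.nio.charset.CodingErrorAction

/** Decodes bytes as UTF-8, returning null if they are malformed or truncated. */
fun decodeStrictUtf8(bytes: ByteArray): String? =
    try {
        Charsets.UTF_8.newDecoder()
            .onMalformedInput(CodingErrorAction.REPORT)
            .onUnmappableCharacter(CodingErrorAction.REPORT)
            .decode(ByteBuffer.wrap(bytes))
            .toString()
    } catch (e: CharacterCodingException) {
        null
    }

/**
 * Hypothetical model of the buffering: token pieces accumulate until the
 * pending bytes form valid UTF-8, then they are delivered in one onResult call.
 * Bytes that never complete a valid sequence are simply left undelivered here.
 */
fun emitBuffered(tokenPieces: List<ByteArray>, onResult: (String) -> Unit) {
    var pending = ByteArray(0)
    for (piece in tokenPieces) {
        pending += piece
        val text = decodeStrictUtf8(pending) ?: continue
        onResult(text)
        pending = ByteArray(0)
    }
}
```

If the two bytes of a character such as "é" (0xC3, 0xA9) arrive as separate token pieces, they are delivered in a single callback, so the number of onResult calls is at most the number of tokens rather than equal to it. That is why a token-count bound is better checked against a stats field than against results.size.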
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmGenerationConfigTest.kt:240
This file defines a custom assertThrows that catches all Throwable. The project already uses JUnit’s org.junit.Assert.assertThrows in other Android tests, which avoids swallowing Errors and keeps stack traces consistent. Consider switching to the standard JUnit helper and deleting the custom implementation.
private fun assertThrows(exClass: Class<out Throwable>, block: () -> Unit) {
    try {
        block()
        fail("Expected ${exClass.simpleName} but no exception was thrown")
    } catch (e: Throwable) {
        assertTrue(
            "Expected ${exClass.simpleName} but got ${e.javaClass.simpleName}: ${e.message}",
            exClass.isInstance(e),
        )
    }
}
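One concrete hazard in the helper quoted above is that fail() throws an AssertionError inside the try, so the helper's own catch (e: Throwable) intercepts it and the "no exception was thrown" case is reported as a type mismatch. Switching to JUnit's helper, as suggested, avoids maintaining this code at all. If a local helper were kept instead, a shape along these lines would sidestep that particular problem; `expectThrows` is a hypothetical name and the sketch uses plain AssertionError rather than JUnit calls so that it stands alone.

```kotlin
/**
 * Hypothetical replacement sketch: runs the block, then asserts outside the
 * try so the helper never catches its own failure. Returns the expected
 * throwable so callers can inspect its message.
 */
fun <T : Throwable> expectThrows(exClass: Class<T>, block: () -> Unit): T {
    // Capture the outcome first; no assertion is raised inside the try.
    val thrown: Throwable? = try {
        block()
        null
    } catch (e: Throwable) {
        e
    }
    if (thrown == null) {
        throw AssertionError("Expected ${exClass.simpleName} but no exception was thrown")
    }
    if (!exClass.isInstance(thrown)) {
        // Keep the unexpected throwable as the cause so its stack trace survives.
        throw AssertionError(
            "Expected ${exClass.simpleName} but got ${thrown.javaClass.simpleName}: ${thrown.message}",
            thrown,
        )
    }
    return exClass.cast(thrown)
}
```

A use-after-close test could then call something like expectThrows(IllegalStateException::class.java) { ... } and check the returned exception's message, though the exception type here is only an example and not taken from the PR.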
Summary: Tests the production LlmGenerationConfig builder API, error paths (invalid model, empty prompt, use-after-close), and callback contract (ordering, frequency, JSON schema). Covers OKR 3.1 (E2E) and 3.2 (feature testing).
Differential Revision: D105886676
@psiddh has imported this pull request. If you are a Meta employee, you can view this in D105886676.