Add LlmModule thread safety instrumentation tests#19715
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19715
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@psiddh has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105886777. |
This PR needs a
|
There was a problem hiding this comment.
Pull request overview
Adds Android instrumentation tests that exercise the documented concurrency/lifecycle contract of org.pytorch.executorch.extension.llm.LlmModule, focusing on concurrent generate() calls, cross-thread stop(), and close() idempotency.
Changes:
- Introduces a new
LlmThreadSafetyTestandroidTest suite forLlmModule. - Adds coverage for concurrent generation behavior, stop signaling from another thread, idle-stop behavior, and use-after-close semantics.
Comments suppressed due to low confidence (2)
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmThreadSafetyTest.kt:60
- Same issue as above for the tokenizer fixture: if the resource stream is null,
.use {}will NPE before therequireNotNullmessage is evaluated. Use therequireNotNull(getResourceAsStream(...)) { ... }.use { ... }pattern to get a clear failure when android_test_setup.sh hasn’t run.
javaClass.getResourceAsStream(TOKENIZER_FILE_NAME).use { stream ->
requireNotNull(stream) {
"Test resource $TOKENIZER_FILE_NAME not found; did android_test_setup.sh run?"
}
FileUtils.copyInputStreamToFile(stream, tokenizerFile)
}
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmThreadSafetyTest.kt:199
thresholdReached.await()/generateDone.await()are unbounded waits. If token callbacks never arrive (e.g., model load failure), this will hang until the overall test timeout. Usingawait(timeout, unit)with an assert (and possibly callingllmModule.stop()on timeout) makes failures faster and helps avoid orphaned generate threads.
// Wait for exactly TOKEN_THRESHOLD tokens, then signal stop
thresholdReached.await()
llmModule.stop()
// Wait for generate() to return
generateDone.await()
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Summary: - Differential Revision: D105886777
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (4)
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmThreadSafetyTest.kt:118
- The test uses several 20s awaits (e.g., waiting for the first token) and a 30s method timeout. Existing LLM instrumentation/perf tests allow much longer for time-to-first-token on emulator (often up to ~30s) and overall test wall-time (minutes). To avoid CI flakes, consider bumping these awaits/timeouts (or making them configurable via instrumentation args) so they can accommodate slow environments.
assertTrue(
"Thread A did not produce a token in time",
threadAProducedToken.await(20, TimeUnit.SECONDS),
)
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmThreadSafetyTest.kt:100
- testConcurrentGenerateThrowsOrSerializes() intends to guarantee Thread A is still holding the lock when Thread B calls generate(), but the callback only counts down a latch and does not block generation. On fast devices, Thread A could finish soon after the first token, so Thread B may never contend for the lock and the test won’t actually validate serialization behavior. Consider adding a second latch to pause Thread A’s callback after the first token until Thread B has entered generate() (or until Thread B signals it is blocked), then release Thread A to complete.
override fun onResult(result: String) {
threadATokens.incrementAndGet()
threadAProducedToken.countDown()
}
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmThreadSafetyTest.kt:136
- Catching RuntimeException here is too broad and can mask real bugs (e.g., NullPointerException) as an “expected” outcome. LlmModule.generate() failures are expected to surface as ExecutorchRuntimeException (a RuntimeException subclass) or IllegalStateException; consider catching ExecutorchRuntimeException explicitly instead of all RuntimeException so unexpected runtime errors still fail the test.
} catch (_: RuntimeException) {
// Valid: serialized second generate() may fail (e.g., dirty KV cache state)
threadBRejected.set(true)
} catch (e: Exception) {
extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmThreadSafetyTest.kt:182
- testStopFromDifferentThread() can be timing-flaky: after thresholdReached is counted down, generation continues running while the main thread wakes up and calls stop(), so on a fast device it’s possible to reach LONG_SEQ_LEN before stop() is issued, violating the final assertion. To make the contract deterministic, consider pausing generation at the threshold (e.g., have onResult await a “stopIssued” latch after counting down thresholdReached) so stop() is guaranteed to happen before additional tokens are produced.
override fun onResult(result: String) {
if (tokensReceived.incrementAndGet() == TOKEN_THRESHOLD) {
thresholdReached.countDown()
}
}
|
@psiddh has imported this pull request. If you are a Meta employee, you can view this in D105886777. |
Summary: -
Differential Revision: D105886777