-
Notifications
You must be signed in to change notification settings - Fork 1k
Add LlmModule thread safety instrumentation tests #19715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+297
−0
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
297 changes: 297 additions & 0 deletions
297
...oid/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmThreadSafetyTest.kt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,297 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|
|
||
| package org.pytorch.executorch | ||
|
|
||
| import androidx.test.ext.junit.runners.AndroidJUnit4 | ||
| import androidx.test.filters.LargeTest | ||
| import java.io.File | ||
| import java.io.IOException | ||
| import java.util.concurrent.CountDownLatch | ||
| import java.util.concurrent.TimeUnit | ||
| import java.util.concurrent.atomic.AtomicBoolean | ||
| import java.util.concurrent.atomic.AtomicInteger | ||
| import org.apache.commons.io.FileUtils | ||
| import org.junit.After | ||
| import org.junit.Assert.assertFalse | ||
| import org.junit.Assert.assertTrue | ||
| import org.junit.Assert.fail | ||
| import org.junit.Before | ||
| import org.junit.Test | ||
| import org.junit.runner.RunWith | ||
| import org.pytorch.executorch.TestFileUtils.getTestFilePath | ||
| import org.pytorch.executorch.extension.llm.LlmCallback | ||
| import org.pytorch.executorch.extension.llm.LlmModule | ||
|
|
||
| /** | ||
| * Thread safety contract tests for [LlmModule]. | ||
| * | ||
| * Validates the concurrency contract: LlmModule uses a ReentrantLock to serialize generate() calls | ||
| * and an atomic stop flag. These tests verify defined behavior (serialization, rejection, atomic | ||
| * stop signaling, idempotent close) rather than relying on timing-dependent races. | ||
| */ | ||
| @LargeTest | ||
| @RunWith(AndroidJUnit4::class) | ||
| class LlmThreadSafetyTest { | ||
|
|
||
| private lateinit var llmModule: LlmModule | ||
|
|
||
| @Before | ||
| @Throws(IOException::class) | ||
| fun setUp() { | ||
| val pteFile = File(getTestFilePath(TEST_FILE_NAME)) | ||
| requireNotNull(javaClass.getResourceAsStream(TEST_FILE_NAME)) { | ||
| "Test resource $TEST_FILE_NAME not found; did android_test_setup.sh run?" | ||
| } | ||
| .use { stream -> FileUtils.copyInputStreamToFile(stream, pteFile) } | ||
|
|
||
| val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) | ||
| requireNotNull(javaClass.getResourceAsStream(TOKENIZER_FILE_NAME)) { | ||
| "Test resource $TOKENIZER_FILE_NAME not found; did android_test_setup.sh run?" | ||
| } | ||
| .use { stream -> FileUtils.copyInputStreamToFile(stream, tokenizerFile) } | ||
|
|
||
| llmModule = | ||
| LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) | ||
| } | ||
|
|
||
| @After | ||
| fun tearDown() { | ||
| if (::llmModule.isInitialized) { | ||
| try { | ||
| llmModule.close() | ||
| } catch (_: IllegalStateException) { | ||
| // Already closed by test | ||
| } | ||
|
psiddh marked this conversation as resolved.
|
||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Contract: concurrent generate() calls are either serialized (ReentrantLock) or rejected | ||
| * (IllegalStateException). Neither outcome should crash or corrupt state. | ||
| * | ||
| * Strategy: use a latch inside Thread A's callback to guarantee it holds the lock before Thread B | ||
| * attempts generate(). | ||
| */ | ||
| @Test(timeout = 30_000) | ||
| fun testConcurrentGenerateThrowsOrSerializes() { | ||
| val threadAProducedToken = CountDownLatch(1) | ||
| val threadATokens = AtomicInteger(0) | ||
| val threadBTokens = AtomicInteger(0) | ||
| val threadBRejected = AtomicBoolean(false) | ||
| val crashed = AtomicBoolean(false) | ||
| val threadADone = CountDownLatch(1) | ||
| val threadBDone = CountDownLatch(1) | ||
|
|
||
| val threadA = Thread { | ||
| try { | ||
| llmModule.generate( | ||
| TEST_PROMPT, | ||
| LONG_SEQ_LEN, | ||
| object : LlmCallback { | ||
| override fun onResult(result: String) { | ||
| threadATokens.incrementAndGet() | ||
| threadAProducedToken.countDown() | ||
| } | ||
|
|
||
| override fun onStats(stats: String) {} | ||
| }, | ||
| ) | ||
| } catch (e: Exception) { | ||
| crashed.set(true) | ||
| } finally { | ||
| threadADone.countDown() | ||
| } | ||
| } | ||
|
|
||
| val threadB = Thread { | ||
| try { | ||
| // Wait until Thread A is actively generating (holds the lock) | ||
| assertTrue( | ||
| "Thread A did not produce a token in time", | ||
| threadAProducedToken.await(20, TimeUnit.SECONDS), | ||
| ) | ||
| llmModule.generate( | ||
|
psiddh marked this conversation as resolved.
|
||
| TEST_PROMPT, | ||
| SHORT_SEQ_LEN, | ||
| object : LlmCallback { | ||
| override fun onResult(result: String) { | ||
| threadBTokens.incrementAndGet() | ||
| } | ||
|
|
||
| override fun onStats(stats: String) {} | ||
| }, | ||
| ) | ||
| } catch (_: IllegalStateException) { | ||
| // Valid: lock rejects concurrent access | ||
| threadBRejected.set(true) | ||
| } catch (_: RuntimeException) { | ||
| // Valid: serialized second generate() may fail (e.g., dirty KV cache state) | ||
| threadBRejected.set(true) | ||
| } catch (e: Exception) { | ||
| crashed.set(true) | ||
| } finally { | ||
| threadBDone.countDown() | ||
| } | ||
| } | ||
|
|
||
| threadA.start() | ||
| threadB.start() | ||
| assertTrue("Thread A did not finish in time", threadADone.await(20, TimeUnit.SECONDS)) | ||
| assertTrue("Thread B did not finish in time", threadBDone.await(20, TimeUnit.SECONDS)) | ||
|
|
||
| assertFalse("No crash during concurrent generate() calls", crashed.get()) | ||
| // Either Thread B was rejected OR it blocked and then succeeded | ||
| val threadBSucceeded = threadBTokens.get() > 0 | ||
| assertTrue( | ||
| "Thread B must either be rejected (IllegalStateException) or serialize and succeed", | ||
| threadBRejected.get() || threadBSucceeded, | ||
| ) | ||
| assertTrue("Thread A must have produced tokens", threadATokens.get() > 0) | ||
| } | ||
|
|
||
| /** | ||
| * Contract: stop() sets an atomic flag checked between token emissions. Once set, generate() | ||
| * returns after completing the current token. | ||
| * | ||
| * Strategy: use a latch to guarantee stop() fires only after TOKEN_THRESHOLD tokens are received, | ||
| * then verify generation terminated. | ||
| */ | ||
| @Test(timeout = 30_000) | ||
| fun testStopFromDifferentThread() { | ||
| val tokensReceived = AtomicInteger(0) | ||
| val thresholdReached = CountDownLatch(1) | ||
| val generateDone = CountDownLatch(1) | ||
| val crashed = AtomicBoolean(false) | ||
|
|
||
| val generateThread = Thread { | ||
| try { | ||
| llmModule.generate( | ||
| TEST_PROMPT, | ||
| LONG_SEQ_LEN, | ||
| object : LlmCallback { | ||
| override fun onResult(result: String) { | ||
| if (tokensReceived.incrementAndGet() == TOKEN_THRESHOLD) { | ||
| thresholdReached.countDown() | ||
| } | ||
| } | ||
|
|
||
| override fun onStats(stats: String) {} | ||
| }, | ||
| ) | ||
| } catch (e: Exception) { | ||
| crashed.set(true) | ||
| } finally { | ||
| generateDone.countDown() | ||
| } | ||
| } | ||
|
|
||
| generateThread.start() | ||
|
|
||
| // Wait for exactly TOKEN_THRESHOLD tokens, then signal stop | ||
| assertTrue( | ||
| "Did not reach token threshold in time", | ||
| thresholdReached.await(20, TimeUnit.SECONDS), | ||
| ) | ||
| llmModule.stop() | ||
|
|
||
| // Wait for generate() to return | ||
| assertTrue("Generate did not finish in time", generateDone.await(20, TimeUnit.SECONDS)) | ||
|
|
||
| assertFalse("No crash during stop from different thread", crashed.get()) | ||
| assertTrue( | ||
| "Must have received at least TOKEN_THRESHOLD tokens before stop", | ||
| tokensReceived.get() >= TOKEN_THRESHOLD, | ||
| ) | ||
| assertTrue( | ||
| "Generation must terminate (not run to LONG_SEQ_LEN)", | ||
| tokensReceived.get() < LONG_SEQ_LEN, | ||
| ) | ||
| } | ||
|
|
||
| /** | ||
| * Contract: stop() when no generate() is active is a no-op. It must not corrupt state or prevent | ||
| * subsequent generate() calls from succeeding. | ||
| */ | ||
| @Test(timeout = 30_000) | ||
| fun testStopWhenIdle() { | ||
| // stop() with no active generation — should not throw | ||
| llmModule.stop() | ||
|
|
||
| // Verify the module is still functional after idle stop | ||
| val tokensReceived = AtomicInteger(0) | ||
| llmModule.generate( | ||
| TEST_PROMPT, | ||
| SHORT_SEQ_LEN, | ||
| object : LlmCallback { | ||
| override fun onResult(result: String) { | ||
| tokensReceived.incrementAndGet() | ||
| } | ||
|
|
||
| override fun onStats(stats: String) {} | ||
| }, | ||
| ) | ||
|
|
||
| assertTrue( | ||
| "generate() must still work after idle stop()", | ||
| tokensReceived.get() == SHORT_SEQ_LEN, | ||
| ) | ||
| } | ||
|
|
||
| /** | ||
| * Contract: close() is idempotent — calling it multiple times must not double-free native | ||
| * resources. After close(), generate() must throw IllegalStateException. | ||
| */ | ||
| @Test(timeout = 30_000) | ||
| fun testCloseIdempotent() { | ||
| // First, verify the module works | ||
| val tokensReceived = AtomicInteger(0) | ||
| llmModule.generate( | ||
| TEST_PROMPT, | ||
| SHORT_SEQ_LEN, | ||
| object : LlmCallback { | ||
| override fun onResult(result: String) { | ||
| tokensReceived.incrementAndGet() | ||
| } | ||
|
|
||
| override fun onStats(stats: String) {} | ||
| }, | ||
| ) | ||
| assertTrue("Module must produce tokens before close", tokensReceived.get() > 0) | ||
|
|
||
| // Close multiple times — must not crash or double-free | ||
| llmModule.close() | ||
| llmModule.close() | ||
| llmModule.close() | ||
|
|
||
| // generate() after close must throw IllegalStateException | ||
| try { | ||
| llmModule.generate( | ||
| TEST_PROMPT, | ||
| SHORT_SEQ_LEN, | ||
| object : LlmCallback { | ||
| override fun onResult(result: String) {} | ||
|
|
||
| override fun onStats(stats: String) {} | ||
| }, | ||
| ) | ||
| fail("generate() after close() must throw IllegalStateException") | ||
| } catch (_: IllegalStateException) { | ||
| // Expected | ||
| } | ||
| } | ||
|
|
||
| companion object { | ||
| private const val TEST_FILE_NAME = "/stories.pte" | ||
| private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" | ||
| private const val TEST_PROMPT = "Hello" | ||
| private const val SHORT_SEQ_LEN = 16 | ||
| private const val LONG_SEQ_LEN = 64 | ||
| private const val TOKEN_THRESHOLD = 5 | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.