From 65e793bc6355077784bd3e9fa222682155c68aa8 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Thu, 2 Apr 2026 00:36:13 -0700 Subject: [PATCH 1/5] =?UTF-8?q?Android:=20unified=20error=20reporting=20?= =?UTF-8?q?=E2=80=94=20all=20sync=20errors=20throw,=20no=20silent=20failur?= =?UTF-8?q?es?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the mixed error-reporting pattern (int return codes, empty arrays, silent no-ops) with a consistent exception-based contract across Module, LlmModule, and LlmCallback: - Module: all public methods acquire mLock and check destroyed state; loadMethod() throws ExecutorchRuntimeException on native error; execute()/forward() throw IllegalStateException on destroyed module; destroy() throws if lock unavailable (concurrent execution) - LlmModule: all generate/load/prefill methods check destroyed state via volatile flag and throw on native errors; resetNative() deprecated in favor of future close(); stop() intentionally unguarded for interrupt-during-generate; prefill methods return void instead of long - LlmCallback: onError() default logs via android.util.Log with try/catch fallback for JVM unit test environments - ExecutorchRuntimeException: added ALREADY_LOADED (0x04) error code, javadoc on all 19 error codes, "ExecuTorch" casing in error messages - JNI: renamed registrations to match Java native method names (generateNative, loadNative, resetContextNative); removed double exception throw from C++ load(); unknown input typeCode now throws - Tests: updated for void returns and assertThrows; all @Ignore preserved - Benchmark: ModelRunner and LlmModelRunner adapted to try/catch pattern Breaking change under @Experimental — all APIs are still experimental. Co-authored-by: Claude --- .../LlmModuleInstrumentationTest.kt | 5 +- .../executorch/ModuleInstrumentationTest.kt | 22 +-- .../ExecutorchRuntimeException.java | 54 ++++++- .../java/org/pytorch/executorch/Module.java | 79 +++++++--- .../executorch/extension/llm/LlmCallback.java | 8 +- .../executorch/extension/llm/LlmModule.java | 144 +++++++++++------- extension/android/jni/jni_layer.cpp | 10 +- extension/android/jni/jni_layer_llama.cpp | 25 +-- .../org/pytorch/minibench/LlmModelRunner.java | 15 +- .../org/pytorch/minibench/ModelRunner.java | 10 +- 10 files changed, 252 insertions(+), 120 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 4b6c3caed94..d5738773577 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -54,9 +54,7 @@ class LlmModuleInstrumentationTest : LlmCallback { @Test @Throws(IOException::class, URISyntaxException::class) fun testGenerate() { - val loadResult = llmModule.load() - // Check that the model can be load successfully - assertEquals(OK.toLong(), loadResult.toLong()) + llmModule.load() llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) assertEquals(results.size.toLong(), SEQ_LEN.toLong()) @@ -277,7 +275,6 @@ class LlmModuleInstrumentationTest : LlmCallback { private const val TEST_FILE_NAME = "/stories.pte" private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" private const val TEST_PROMPT = "Hello" - private const val OK = 0x00 private const val SEQ_LEN = 32 } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 99d53b6dba3..408fc23f542 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -66,8 +66,7 @@ class ModuleInstrumentationTest { fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + module.loadMethod(FORWARD_METHOD) val results = module.forward() Assert.assertTrue(results[0].isTensor) @@ -96,8 +95,7 @@ class ModuleInstrumentationTest { fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val loadMethod = module.loadMethod(NONE_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + Assert.assertThrows(RuntimeException::class.java) { module.loadMethod(NONE_METHOD) } } @Test(expected = RuntimeException::class) @@ -105,8 +103,7 @@ class ModuleInstrumentationTest { fun testNonPteFile() { val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + module.loadMethod(FORWARD_METHOD) } @Test @@ -116,8 +113,7 @@ class ModuleInstrumentationTest { module.destroy() - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) + Assert.assertThrows(RuntimeException::class.java) { module.loadMethod(FORWARD_METHOD) } } @Test @@ -125,13 +121,11 @@ class ModuleInstrumentationTest { fun testForwardOnDestroyedModule() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + module.loadMethod(FORWARD_METHOD) module.destroy() - val results = module.forward() - Assert.assertEquals(0, results.size.toLong()) + Assert.assertThrows(RuntimeException::class.java) { module.forward() } } @Ignore( @@ -175,9 +169,5 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" - private const val OK = 0x00 - private const val INVALID_STATE = 0x2 - private const val INVALID_ARGUMENT = 0x12 - private const val ACCESS_FAILED = 0x22 } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index 102b96ab686..e0fda73cc06 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -12,34 +12,83 @@ import java.util.HashMap; import java.util.Map; +/** + * Base exception for all ExecuTorch runtime errors. Each instance carries an integer error code + * corresponding to the native {@code runtime/core/error.h} values, accessible via {@link + * #getErrorCode()}. + */ public class ExecutorchRuntimeException extends RuntimeException { // Error code constants - keep in sync with runtime/core/error.h + // System errors + + /** Operation completed successfully. */ public static final int OK = 0x00; + + /** An unexpected internal error occurred in the runtime. */ public static final int INTERNAL = 0x01; + + /** The runtime or method is in an invalid state for the requested operation. */ public static final int INVALID_STATE = 0x02; + + /** The method has finished execution and has no more work to do. */ public static final int END_OF_METHOD = 0x03; + /** A required resource has already been loaded. */ + public static final int ALREADY_LOADED = 0x04; + // Logical errors + + /** The requested operation is not supported by this build or backend. */ public static final int NOT_SUPPORTED = 0x10; + + /** The requested operation has not been implemented. */ public static final int NOT_IMPLEMENTED = 0x11; + + /** One or more arguments passed to the operation are invalid. */ public static final int INVALID_ARGUMENT = 0x12; + + /** A value or tensor has an unexpected type. */ public static final int INVALID_TYPE = 0x13; + + /** A required operator kernel is not registered. */ public static final int OPERATOR_MISSING = 0x14; + + /** The maximum number of registered kernels has been exceeded. */ public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15; + + /** A kernel with the same name is already registered. */ public static final int REGISTRATION_ALREADY_REGISTERED = 0x16; // Resource errors + + /** A required resource (file, tensor, program) was not found. */ public static final int NOT_FOUND = 0x20; + + /** A memory allocation failed. */ public static final int MEMORY_ALLOCATION_FAILED = 0x21; + + /** Access to a resource was denied or failed. */ public static final int ACCESS_FAILED = 0x22; + + /** The loaded program is malformed or incompatible. */ public static final int INVALID_PROGRAM = 0x23; + + /** External data referenced by the program is invalid or missing. */ public static final int INVALID_EXTERNAL_DATA = 0x24; + + /** The system has run out of a required resource. */ public static final int OUT_OF_RESOURCES = 0x25; // Delegate errors + + /** A delegate reported an incompatible model or configuration. */ public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30; + + /** A delegate failed to allocate required memory. */ public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31; + + /** A delegate received an invalid or stale handle. */ public static final int DELEGATE_INVALID_HANDLE = 0x32; private static final Map ERROR_CODE_MESSAGES; @@ -52,6 +101,7 @@ public class ExecutorchRuntimeException extends RuntimeException { map.put(INTERNAL, "Internal error"); map.put(INVALID_STATE, "Invalid state"); map.put(END_OF_METHOD, "End of method reached"); + map.put(ALREADY_LOADED, "Already loaded"); // Logical errors map.put(NOT_SUPPORTED, "Operation not supported"); map.put(NOT_IMPLEMENTED, "Operation not implemented"); @@ -83,7 +133,7 @@ static String formatMessage(int errorCode, String details) { String safeDetails = details != null ? details : "No details provided"; return String.format( - "[Executorch Error 0x%s] %s: %s", + "[ExecuTorch Error 0x%s] %s: %s", Integer.toHexString(errorCode), baseMessage, safeDetails); } @@ -111,10 +161,12 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } + /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; } + /** Returns detailed log output captured from the native runtime, if available. */ public String getDetailedError() { return ErrorHelper.getDetailedErrorLogs(); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index f7e2e37dcec..05e1e5b88cf 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -8,7 +8,6 @@ package org.pytorch.executorch; -import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; @@ -130,11 +129,10 @@ public EValue[] forward(EValue... inputs) { * @return return value from the method. */ public EValue[] execute(String methodName, EValue... inputs) { + mLock.lock(); try { - mLock.lock(); if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new EValue[0]; + throw new IllegalStateException("Module has been destroyed"); } return executeNative(methodName, inputs); } finally { @@ -151,17 +149,17 @@ public EValue[] execute(String methodName, EValue... inputs) { * synchronous, and will block until the method is loaded. Therefore, it is recommended to call * this on a background thread. However, users need to make sure that they don't execute before * this function returns. - * - * @return the Error code if there was an error loading the method */ - public int loadMethod(String methodName) { + public void loadMethod(String methodName) { + mLock.lock(); try { - mLock.lock(); if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return 0x2; // InvalidState + throw new IllegalStateException("Module has been destroyed"); + } + int errorCode = loadMethodNative(methodName); + if (errorCode != 0) { + throw new ExecutorchRuntimeException(errorCode, "Failed to load method: " + methodName); } - return loadMethodNative(methodName); } finally { mLock.unlock(); } @@ -184,8 +182,20 @@ public int loadMethod(String methodName) { * * @return name of methods in this Module */ + public String[] getMethods() { + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return getMethodsNative(); + } finally { + mLock.unlock(); + } + } + @DoNotStrip - public native String[] getMethods(); + private native String[] getMethodsNative(); /** * Get the corresponding @MethodMetadata for a method @@ -194,11 +204,19 @@ public int loadMethod(String methodName) { * @return @MethodMetadata for this method */ public MethodMetadata getMethodMetadata(String name) { - MethodMetadata methodMetadata = mMethodMetadata.get(name); - if (methodMetadata == null) { - throw new IllegalArgumentException("method " + name + " does not exist for this module"); + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + MethodMetadata methodMetadata = mMethodMetadata.get(name); + if (methodMetadata == null) { + throw new IllegalArgumentException("method " + name + " does not exist for this module"); + } + return methodMetadata; + } finally { + mLock.unlock(); } - return methodMetadata; } @DoNotStrip @@ -210,7 +228,15 @@ public static String[] readLogBufferStatic() { /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { - return readLogBufferNative(); + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return readLogBufferNative(); + } finally { + mLock.unlock(); + } } @DoNotStrip @@ -224,8 +250,20 @@ public String[] readLogBuffer() { * @return true if the etdump was successfully written, false otherwise. */ @Experimental + public boolean etdump() { + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return etdumpNative(); + } finally { + mLock.unlock(); + } + } + @DoNotStrip - public native boolean etdump(); + private native boolean etdumpNative(); /** * Explicitly destroys the native Module object. Calling this method is not required, as the @@ -241,10 +279,7 @@ public void destroy() { mLock.unlock(); } } else { - Log.w( - "ExecuTorch", - "Destroy was called while the module was in use. Resources will not be immediately" - + " released."); + throw new IllegalStateException("Cannot destroy module while method is executing"); } } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java index 4e834d06721..9045fe68857 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java @@ -46,5 +46,11 @@ default void onStats(String stats) {} * @param message Human-readable error description */ @DoNotStrip - default void onError(int errorCode, String message) {} + default void onError(int errorCode, String message) { + try { + android.util.Log.e("ExecuTorch", "LLM error " + errorCode + ": " + message); + } catch (RuntimeException e) { + System.err.println("ExecuTorch LLM error " + errorCode + ": " + message); + } + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index 30831591fdd..c1ad12f7be3 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -13,6 +13,7 @@ import java.nio.ByteBuffer; import java.util.List; import org.pytorch.executorch.ExecuTorchRuntime; +import org.pytorch.executorch.ExecutorchRuntimeException; import org.pytorch.executorch.annotations.Experimental; /** @@ -29,6 +30,7 @@ public class LlmModule { public static final int MODEL_TYPE_MULTIMODAL = 2; private final HybridData mHybridData; + private volatile boolean mDestroyed = false; private static final int DEFAULT_SEQ_LEN = 128; private static final boolean DEFAULT_ECHO = true; private static final float DEFAULT_TEMPERATURE = -1.0f; @@ -153,7 +155,14 @@ public LlmModule(LlmModuleConfig config) { config.getNumEos()); } + private void checkNotDestroyed() { + if (mDestroyed) throw new IllegalStateException("LlmModule has been destroyed"); + } + + @Deprecated public void resetNative() { + if (mDestroyed) return; + mDestroyed = true; mHybridData.resetNative(); } @@ -163,8 +172,9 @@ public void resetNative() { * @param prompt Input prompt * @param llmCallback callback object to receive results. */ - public int generate(String prompt, LlmCallback llmCallback) { - return generate( + public void generate(String prompt, LlmCallback llmCallback) { + checkNotDestroyed(); + generate( prompt, DEFAULT_SEQ_LEN, llmCallback, @@ -181,8 +191,9 @@ public int generate(String prompt, LlmCallback llmCallback) { * @param seqLen sequence length * @param llmCallback callback object to receive results. */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback) { - return generate( + public void generate(String prompt, int seqLen, LlmCallback llmCallback) { + checkNotDestroyed(); + generate( null, 0, 0, @@ -203,8 +214,9 @@ public int generate(String prompt, int seqLen, LlmCallback llmCallback) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate(String prompt, LlmCallback llmCallback, boolean echo) { - return generate( + public void generate(String prompt, LlmCallback llmCallback, boolean echo) { + checkNotDestroyed(); + generate( null, 0, 0, @@ -226,9 +238,9 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - return generate( - prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); + public void generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { + checkNotDestroyed(); + generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); } /** @@ -242,7 +254,22 @@ public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public native int generate( + public void generate( + String prompt, + int seqLen, + LlmCallback llmCallback, + boolean echo, + float temperature, + int numBos, + int numEos) { + checkNotDestroyed(); + int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); + } + } + + private native int generateNative( String prompt, int seqLen, LlmCallback llmCallback, @@ -258,13 +285,14 @@ public native int generate( * @param config the config for generation * @param llmCallback callback object to receive results */ - public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { + public void generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { + checkNotDestroyed(); int seqLen = config.getSeqLen(); boolean echo = config.isEcho(); float temperature = config.getTemperature(); int numBos = config.getNumBos(); int numEos = config.getNumEos(); - return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -279,7 +307,7 @@ public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCa * @param llmCallback callback object to receive results. * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate( + public void generate( int[] image, int width, int height, @@ -288,7 +316,8 @@ public int generate( int seqLen, LlmCallback llmCallback, boolean echo) { - return generate( + checkNotDestroyed(); + generate( image, width, height, @@ -315,7 +344,7 @@ public int generate( * @param echo indicate whether to echo the input prompt or not (text completion vs chat) * @param temperature temperature for sampling (use negative value to use module default) */ - public int generate( + public void generate( int[] image, int width, int height, @@ -325,7 +354,8 @@ public int generate( LlmCallback llmCallback, boolean echo, float temperature) { - return generate( + checkNotDestroyed(); + generate( image, width, height, @@ -354,7 +384,7 @@ public int generate( * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public int generate( + public void generate( int[] image, int width, int height, @@ -366,10 +396,11 @@ public int generate( float temperature, int numBos, int numEos) { + checkNotDestroyed(); if (image != null) { prefillImages(image, width, height, channels); } - return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -379,16 +410,15 @@ public int generate( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillImages(int[] image, int width, int height, int channels) { + public void prefillImages(int[] image, int width, int height, int channels) { + checkNotDestroyed(); int nativeResult = prefillImagesInput(image, width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } /** @@ -408,6 +438,7 @@ public long prefillImages(int[] image, int width, int height, int channels) { */ @Experimental public void prefillImages(ByteBuffer image, int width, int height, int channels) { + checkNotDestroyed(); if (!image.isDirect()) { throw new IllegalArgumentException("Input ByteBuffer must be direct."); } @@ -435,7 +466,7 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels) // starting at the current position, not the base address. int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } } @@ -459,6 +490,7 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels) */ @Experimental public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) { + checkNotDestroyed(); if (!image.isDirect()) { throw new IllegalArgumentException("Input ByteBuffer must be direct."); } @@ -501,7 +533,7 @@ public void prefillNormalizedImage(ByteBuffer image, int width, int height, int // starting at the current position, not the base address. int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } } @@ -520,16 +552,15 @@ private native int prefillNormalizedImagesInputBuffer( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillImages(float[] image, int width, int height, int channels) { + public void prefillImages(float[] image, int width, int height, int channels) { + checkNotDestroyed(); int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillNormalizedImagesInput( @@ -542,16 +573,15 @@ private native int prefillNormalizedImagesInput( * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { + public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { + checkNotDestroyed(); int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); @@ -563,16 +593,15 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { + public void prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { + checkNotDestroyed(); int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillAudioInputFloat( @@ -585,16 +614,15 @@ private native int prefillAudioInputFloat( * @param batch_size Input batch size * @param n_channels Input number of channels * @param n_samples Input number of samples - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { + public void prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { + checkNotDestroyed(); int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } private native int prefillRawAudioInput( @@ -604,16 +632,15 @@ private native int prefillRawAudioInput( * Prefill the KV cache with the given text prompt. * * @param prompt The text prompt to prefill. - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillPrompt(String prompt) { + public void prefillPrompt(String prompt) { + checkNotDestroyed(); int nativeResult = prefillTextInput(prompt); if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); } - return 0; } // returns status @@ -624,7 +651,12 @@ public long prefillPrompt(String prompt) { * *

The startPos will be reset to 0. */ - public native void resetContext(); + public void resetContext() { + checkNotDestroyed(); + resetContextNative(); + } + + private native void resetContextNative(); /** Stop current generate() before it finishes. */ @DoNotStrip @@ -632,5 +664,13 @@ public long prefillPrompt(String prompt) { /** Force loading the module. Otherwise the model is loaded during first generate(). */ @DoNotStrip - public native int load(); + public void load() { + checkNotDestroyed(); + int err = loadNative(); + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model"); + } + } + + private native int loadNative(); } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index beff72119b8..88e9f9e2a12 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -385,6 +385,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass { static const auto toBoolMethod = JEValue::javaClassStatic()->getMethod("toBool"); evalues.emplace_back(static_cast(toBoolMethod(jevalue))); + } else { + std::stringstream ss; + ss << "Unsupported input EValue type code: " << typeCode; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str()); + return {}; } } @@ -564,8 +570,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), makeNativeMethod( "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic), - makeNativeMethod("etdump", ExecuTorchJni::etdump), - makeNativeMethod("getMethods", ExecuTorchJni::getMethods), + makeNativeMethod("etdumpNative", ExecuTorchJni::etdump), + makeNativeMethod("getMethodsNative", ExecuTorchJni::getMethods), makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index b1474288a2f..0690cf4f30f 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -544,30 +544,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { - std::stringstream ss; - ss << "Invalid model type category: " << model_type_category_ - << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " - << MODEL_TYPE_CATEGORY_MULTIMODAL; - executorch::jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - return -1; - } - int result = static_cast(runner_->load()); - if (result != 0) { - std::stringstream ss; - ss << "Failed to load runner: [" << result << "]"; - executorch::jni_helper::throwExecutorchException( - static_cast(result), ss.str().c_str()); - } - return result; + return static_cast(Error::InvalidArgument); + } + return static_cast(runner_->load()); } static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid), - makeNativeMethod("generate", ExecuTorchLlmJni::generate), + makeNativeMethod("generateNative", ExecuTorchLlmJni::generate), makeNativeMethod("stop", ExecuTorchLlmJni::stop), - makeNativeMethod("load", ExecuTorchLlmJni::load), + makeNativeMethod("loadNative", ExecuTorchLlmJni::load), makeNativeMethod( "prefillImagesInput", ExecuTorchLlmJni::prefill_images_input), makeNativeMethod( @@ -588,7 +575,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { "prefillRawAudioInput", ExecuTorchLlmJni::prefill_raw_audio_input), makeNativeMethod( "prefillTextInput", ExecuTorchLlmJni::prefill_text_input), - makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), + makeNativeMethod("resetContextNative", ExecuTorchLlmJni::reset_context), }); } }; diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java index a1b434a37bf..2c8770ca33e 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java @@ -87,10 +87,21 @@ public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { @Override public void handleMessage(android.os.Message msg) { if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mLlmModelRunner.mModule.load(); + int status = 0; + try { + mLlmModelRunner.mModule.load(); + } catch (org.pytorch.executorch.ExecutorchRuntimeException e) { + status = e.getErrorCode(); + } catch (Exception e) { + status = -1; + } mLlmModelRunner.mCallback.onModelLoaded(status); } else if (msg.what == MESSAGE_GENERATE) { - mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); + try { + mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); + } catch (Exception e) { + android.util.Log.e("LlmModelRunner", "generate() failed", e); + } mLlmModelRunner.mCallback.onGenerationStopped(); } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 28f4e3728f0..b2fdeed9bab 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -27,7 +27,15 @@ public void runBenchmark( long loadStart = System.nanoTime(); Module module = Module.load(model.getPath()); - int errorCode = module.loadMethod("forward"); + int errorCode = 0; + try { + module.loadMethod("forward"); + } catch (Exception e) { + errorCode = + (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) + ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() + : -1; + } long loadEnd = System.nanoTime(); for (int i = 0; i < numWarmupIter; i++) { From 4338e136f74b9c9a8ea9959423add572c8ed1bad Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Thu, 2 Apr 2026 06:28:11 -0700 Subject: [PATCH 2/5] Update extension/android/jni/jni_layer_llama.cpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- extension/android/jni/jni_layer_llama.cpp | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 0690cf4f30f..5f582a089da 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -544,9 +544,21 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { + ET_LOG( + Error, + "ExecuTorchLlmJni::load() called but runner_ is null. " + "The model runner was not created or failed to initialize due to a " + "previous configuration or initialization error."); return static_cast(Error::InvalidArgument); } - return static_cast(runner_->load()); + const auto load_result = static_cast(runner_->load()); + if (load_result != static_cast(Error::Ok)) { + ET_LOG( + Error, + "ExecuTorchLlmJni::load() failed in runner_->load() with error code %d.", + static_cast(load_result)); + } + return load_result; } static void registerNatives() { From e734382545c591c7dcc46d56ef9679b8d659de38 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Thu, 2 Apr 2026 06:28:34 -0700 Subject: [PATCH 3/5] Update extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../org/pytorch/executorch/ModuleInstrumentationTest.kt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 408fc23f542..366ec50b767 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -95,7 +95,12 @@ class ModuleInstrumentationTest { fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - Assert.assertThrows(RuntimeException::class.java) { module.loadMethod(NONE_METHOD) } + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.ErrorCode.INVALID_ARGUMENT, exception.getErrorCode()) } @Test(expected = RuntimeException::class) From 28241aa30a893bbd90158ee2f5107140ee81164f Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Thu, 2 Apr 2026 06:28:42 -0700 Subject: [PATCH 4/5] Update extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../java/org/pytorch/executorch/ModuleInstrumentationTest.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 366ec50b767..7a193328c32 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -118,7 +118,7 @@ class ModuleInstrumentationTest { module.destroy() - Assert.assertThrows(RuntimeException::class.java) { module.loadMethod(FORWARD_METHOD) } + Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) } } @Test From 795e56ca4fb56d417885f5e56059cb58533a4b98 Mon Sep 17 00:00:00 2001 From: Siddartha Pothapragada Date: Thu, 2 Apr 2026 06:29:10 -0700 Subject: [PATCH 5/5] Update extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../java/org/pytorch/executorch/extension/llm/LlmCallback.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java index 9045fe68857..ec0413caf2e 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java @@ -49,7 +49,7 @@ default void onStats(String stats) {} default void onError(int errorCode, String message) { try { android.util.Log.e("ExecuTorch", "LLM error " + errorCode + ": " + message); - } catch (RuntimeException e) { + } catch (Throwable t) { System.err.println("ExecuTorch LLM error " + errorCode + ": " + message); } }