Skip to content

Commit fa9bf48

Browse files
Merge pull request #117 from beehive-lab/refactor/tornadovm-planning
Refactor GPU backend planner
2 parents 7879810 + 82db094 commit fa9bf48

101 files changed

Lines changed: 3463 additions & 1846 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/build-and-run.yml

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,107 @@ jobs:
247247
configuration: standard
248248
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-standard.json
249249

250+
- name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode
251+
env:
252+
JAVA_TOOL_OPTIONS: >-
253+
-Dllama.metrics.format=json
254+
-Dllama.metrics.output=file
255+
-Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.json
256+
run: |
257+
cd ${{ github.workspace }}
258+
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
259+
./llama-tornado --gpu --${{ matrix.backend.name }} \
260+
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
261+
--prompt "Say hello" \
262+
--with-prefill-decode
263+
python3 scripts/write_metrics_sidecar.py \
264+
--out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-prefill-decode.meta.json" \
265+
backend="${{ matrix.backend.name }}" \
266+
task=llama-inference \
267+
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
268+
model=Llama-3.2-1B-Instruct \
269+
quantization=Q8_0 \
270+
configuration=prefill-decode \
271+
"flags=--with-prefill-decode" \
272+
prompt="Say hello"
273+
274+
- name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode
275+
env:
276+
JAVA_TOOL_OPTIONS: >-
277+
-Dllama.metrics.format=json
278+
-Dllama.metrics.output=file
279+
-Dllama.metrics.file=${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.json
280+
run: |
281+
cd ${{ github.workspace }}
282+
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
283+
./llama-tornado --gpu --${{ matrix.backend.name }} \
284+
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
285+
--prompt "Say hello" \
286+
--with-prefill-decode --batch-prefill-size 32
287+
python3 scripts/write_metrics_sidecar.py \
288+
--out "${{ runner.temp }}/metrics-${{ matrix.backend.name }}-llama-1b-q8-batch-prefill-decode.meta.json" \
289+
backend="${{ matrix.backend.name }}" \
290+
task=llama-inference \
291+
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
292+
model=Llama-3.2-1B-Instruct \
293+
quantization=Q8_0 \
294+
configuration=batch-prefill-decode \
295+
"flags=--with-prefill-decode --batch-prefill-size 32" \
296+
prompt="Say hello"
297+
298+
# ── PTX-only: CUDA-graph variants ────────────────────────────────────────
299+
- name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Prefill-Decode-CUDA-Graphs
300+
if: matrix.backend.name == 'ptx'
301+
env:
302+
JAVA_TOOL_OPTIONS: >-
303+
-Dllama.metrics.format=json
304+
-Dllama.metrics.output=file
305+
-Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.json
306+
run: |
307+
cd ${{ github.workspace }}
308+
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
309+
./llama-tornado --gpu --ptx \
310+
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
311+
--prompt "Say hello" \
312+
--with-prefill-decode \
313+
--cuda-graphs
314+
python3 scripts/write_metrics_sidecar.py \
315+
--out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-prefill-decode-cuda-graphs.meta.json" \
316+
backend=ptx \
317+
task=llama-inference \
318+
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
319+
model=Llama-3.2-1B-Instruct \
320+
quantization=Q8_0 \
321+
configuration=prefill-decode-cuda-graphs \
322+
"flags=--with-prefill-decode --cuda-graphs" \
323+
prompt="Say hello"
324+
325+
- name: PTX - Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs
326+
if: matrix.backend.name == 'ptx'
327+
env:
328+
JAVA_TOOL_OPTIONS: >-
329+
-Dllama.metrics.format=json
330+
-Dllama.metrics.output=file
331+
-Dllama.metrics.file=${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.json
332+
run: |
333+
cd ${{ github.workspace }}
334+
export PATH="$TORNADOVM_HOME/bin:$JAVA_HOME/bin:$PATH"
335+
./llama-tornado --gpu --ptx \
336+
--model $MODELS_DIR/Llama-3.2-1B-Instruct-Q8_0.gguf \
337+
--prompt "Say hello" \
338+
--with-prefill-decode --batch-prefill-size 32 \
339+
--cuda-graphs
340+
python3 scripts/write_metrics_sidecar.py \
341+
--out "${{ runner.temp }}/metrics-ptx-llama-1b-q8-batch-prefill-decode-cuda-graphs.meta.json" \
342+
backend=ptx \
343+
task=llama-inference \
344+
model_file=Llama-3.2-1B-Instruct-Q8_0.gguf \
345+
model=Llama-3.2-1B-Instruct \
346+
quantization=Q8_0 \
347+
configuration=batch-prefill-decode-cuda-graphs \
348+
"flags=--with-prefill-decode --batch-prefill-size 32 --cuda-graphs" \
349+
prompt="Say hello"
350+
250351
- name: Q8 - Run Qwen3-0.6B-Q8_0.gguf
251352
uses: ./.github/actions/run-inference
252353
with:

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,6 @@
2828
* This class provides core computational operations such as RMS normalization and forward passes through model layers. It supports both CPU and GPU implementations.
2929
* </p>
3030
*
31-
* <p>
32-
* Specifically, it implements:
33-
* <ul>
34-
* <li>{@code rmsnorm} – applies Root Mean Square Layer Normalization to input vectors</li>
35-
* <li>{@code forwardJava} – executes a Forward pass for LLaMA and Mistral models on CPU</li>
36-
* <li>{@code forwardJavaQwen3} – executes a Forward pass for Qwen3 models on CPU</li>
37-
* <li>{@code forwardTornadoVM} – executes a Forward pass using TornadoVM for GPU acceleration</li>
38-
* </ul>
39-
* </p>
4031
*/
4132

4233
public final class InferenceCore {
@@ -643,10 +634,10 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
643634
* Granite uses the same transformer architecture as Llama but with maximal update parameterization (µP)
644635
* scaling factors applied at specific points:
645636
* <ul>
646-
* <li>Embedding scaling: multiply embeddings after lookup</li>
647-
* <li>Attention scaling: use custom multiplier instead of 1/sqrt(headDim)</li>
648-
* <li>Residual scaling: multiply residual connections</li>
649-
* <li>Logit scaling: divide logits by the scaling factor</li>
637+
* <li>Embedding scaling: multiply embeddings after lookup</li>
638+
* <li>Attention scaling: use custom multiplier instead of 1/sqrt(headDim)</li>
639+
* <li>Residual scaling: multiply residual connections</li>
640+
* <li>Logit scaling: divide logits by the scaling factor</li>
650641
* </ul>
651642
*/
652643
public static FloatTensor forwardGranite(Model model, State state, int token, int position) {
@@ -771,24 +762,24 @@ static void copyChunk(FloatTensor in, FloatTensor out, int dim1In, int dim1Out,
771762
*
772763
* <p>This method handles the first phase of processing a token through the transformer model:
773764
* <ol>
774-
* <li>Copies the token embedding from the model's embedding table to the state's buffer</li>
775-
* <li>Delegates the transformer layer processing to TornadoVM through the master plan</li>
765+
* <li>Copies the token embedding from the model's embedding table to the state's buffer</li>
766+
* <li>Delegates the transformer layer processing to TornadoVM through the master plan</li>
776767
* </ol>
777768
*
778769
* <p>The token embedding lookup happens on the CPU using {@link MemorySegment} operations,
779770
* while the subsequent transformer layers processing is offloaded to the accelerator through
780771
* TornadoVM for improved performance.
781772
*
782773
* @param model
783-
* The Llama model containing weights and configuration parameters
774+
* The Llama model containing weights and configuration parameters
784775
* @param state
785-
* The current execution state holding input/output tensors and temporary buffers
776+
* The current execution state holding input/output tensors and temporary buffers
786777
* @param token
787-
* The input token ID to process
778+
* The input token ID to process
788779
* @param position
789-
* The position of this token in the sequence context window
780+
* The position of this token in the sequence context window
790781
* @param tornadoVMMasterPlan
791-
* The execution plan for TornadoVM acceleration
782+
* The execution plan for TornadoVM acceleration
792783
* @return FloatTensor containing the output logits for token prediction
793784
*/
794785
public static FloatArray forwardTornadoVM(Model model, State state, int token, int position, TornadoVMMasterPlan tornadoVMMasterPlan) {
@@ -814,7 +805,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
814805
default -> throw new IllegalArgumentException("Unsupported weight type: " + weights.getWeightType());
815806
}
816807

817-
return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
808+
return tornadoVMMasterPlan.tornadoVMForwardDecode(position);
818809
}
819810

820811
}

0 commit comments

Comments
 (0)