Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9ee34a1
[prf/dec] Replace `LlamaState` and `LlamaConfiguration` with generic …
orionpapadakis Jun 10, 2026
d23d10e
[prf/dec] Make State's `qDim` and `kvDim` model-agnostic for batch-p…
orionpapadakis Jun 10, 2026
474255c
[prf/dec] Add initial impl of prefill-decode and batch-prefill-decode…
orionpapadakis Jun 10, 2026
f5d002a
[prf/dec][ci] Add prefill-decode variants ci steps for Qwen3
orionpapadakis Jun 11, 2026
e0cddbe
[prf/dec][ci] Use Qwen3-0.6B instead of Qwen3-4B in CI workflows
orionpapadakis Jun 11, 2026
63bf168
[tool][WIP] Introduce tool-calling capabilities by GPULlama3.java
orionpapadakis Apr 17, 2026
12f1583
[tools] Refactor tool-calling architecture: modularize components and…
orionpapadakis Apr 18, 2026
37084ed
[introduce] Add standalone ToolCallingApp for tool-calling functional…
orionpapadakis Apr 18, 2026
6fbe743
[tools][wip] Add support for Llama 3.2 tool-call injection: batch too…
orionpapadakis Apr 21, 2026
0881f95
[tools][wip] Add fix for tool-calling
orionpapadakis Apr 28, 2026
5033dfe
[tools] Remove standalone `ToolCallingApp` and its references as it w…
orionpapadakis May 15, 2026
5f05176
[tools] Remove tool-calling classes as redundant after dropping stand…
orionpapadakis May 15, 2026
0aa5a1a
[tools] Add `supportsToolCalling` implementation and enhance tool cal…
orionpapadakis May 15, 2026
f72ba11
[tools] Extend `ToolCallExtract` with optional `id` and unify tool ca…
orionpapadakis May 15, 2026
2606f8b
[tools] Add support for batch tool call encoding across multiple chat…
orionpapadakis May 15, 2026
8295ebd
[tools] Unify tool call parsing logic, streamline method names, and e…
orionpapadakis May 15, 2026
ce83aea
Add default temperature and top-p resolution based on model formats, …
orionpapadakis May 15, 2026
76058fa
[tools] Simplify section comments in `ToolCallParserUtils` for improv…
orionpapadakis May 15, 2026
a287312
[tools][fix] Enhance JSON parsing in `ToolCallParserUtils` with strin…
orionpapadakis Jun 12, 2026
828b2f0
[tools][fix] Update Qwen3ChatFormat to encode tool results using `<to…
orionpapadakis Jun 12, 2026
fa8429e
[tools][test] Add unit tests for `ToolCallParserUtils`, covering sing…
orionpapadakis Jun 12, 2026
674170f
Add thinking on/off control support
orionpapadakis Jun 12, 2026
91333d1
Add support for canonical `<think>` token tracking and usage in think…
orionpapadakis Jun 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 98 additions & 4 deletions .github/workflows/build-and-run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,62 @@ jobs:
flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs
metrics_file: ${{ runner.temp }}/metrics-ptx-llama-1b-f16-batch-prefill-decode-cuda-graphs.json

- name: FP16 - Run Qwen3-4B-f16.gguf
- name: FP16 - Run Qwen3-0.6B-f16.gguf
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-4B-f16.gguf
model: Qwen3-4B
model_file: Qwen3-0.6B-f16.gguf
model: Qwen3-0.6B
quantization: F16
configuration: standard
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-standard.json
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-standard.json

- name: FP16 - Run Qwen3-0.6B-f16.gguf - Prefill-Decode
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-f16.gguf
model: Qwen3-0.6B
quantization: F16
configuration: prefill-decode
flags: --with-prefill-decode
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-prefill-decode.json

- name: FP16 - Run Qwen3-0.6B-f16.gguf - Batch-Prefill-Decode
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-f16.gguf
model: Qwen3-0.6B
quantization: F16
configuration: batch-prefill-decode
flags: --with-prefill-decode --batch-prefill-size 32
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-batch-prefill-decode.json

# PTX-only: CUDA-graph variants
- name: PTX - FP16 - Run Qwen3-0.6B-f16.gguf - Prefill-Decode-CUDA-Graphs
if: matrix.backend.name == 'ptx'
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-f16.gguf
model: Qwen3-0.6B
quantization: F16
configuration: prefill-decode-cuda-graphs
flags: --with-prefill-decode --cuda-graphs
metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0.6b-f16-prefill-decode-cuda-graphs.json

- name: PTX - FP16 - Run Qwen3-0.6B-f16.gguf - Batch-Prefill-Decode-CUDA-Graphs
if: matrix.backend.name == 'ptx'
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-f16.gguf
model: Qwen3-0.6B
quantization: F16
configuration: batch-prefill-decode-cuda-graphs
flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs
metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0.6b-f16-batch-prefill-decode-cuda-graphs.json

- name: FP16 - Run Mistral-7B-Instruct-v0.3.fp16.gguf
uses: ./.github/actions/run-inference
Expand Down Expand Up @@ -358,6 +405,53 @@ jobs:
configuration: standard
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-standard.json

- name: Q8 - Run Qwen3-0.6B-Q8_0.gguf - Prefill-Decode
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-Q8_0.gguf
model: Qwen3-0.6B
quantization: Q8_0
configuration: prefill-decode
flags: --with-prefill-decode
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-prefill-decode.json

- name: Q8 - Run Qwen3-0.6B-Q8_0.gguf - Batch-Prefill-Decode
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-Q8_0.gguf
model: Qwen3-0.6B
quantization: Q8_0
configuration: batch-prefill-decode
flags: --with-prefill-decode --batch-prefill-size 32
metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-batch-prefill-decode.json

# PTX-only: CUDA-graph variants
- name: PTX - Q8 - Run Qwen3-0.6B-Q8_0.gguf - Prefill-Decode-CUDA-Graphs
if: matrix.backend.name == 'ptx'
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-Q8_0.gguf
model: Qwen3-0.6B
quantization: Q8_0
configuration: prefill-decode-cuda-graphs
flags: --with-prefill-decode --cuda-graphs
metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0-6b-q8-prefill-decode-cuda-graphs.json

- name: PTX - Q8 - Run Qwen3-0.6B-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs
if: matrix.backend.name == 'ptx'
uses: ./.github/actions/run-inference
with:
backend: ${{ matrix.backend.name }}
model_file: Qwen3-0.6B-Q8_0.gguf
model: Qwen3-0.6B
quantization: Q8_0
configuration: batch-prefill-decode-cuda-graphs
flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs
metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0-6b-q8-batch-prefill-decode-cuda-graphs.json

- name: Q8 - Run Phi-3-mini-4k-instruct-Q8_0.gguf
uses: ./.github/actions/run-inference
with:
Expand Down
16 changes: 8 additions & 8 deletions src/main/java/org/beehive/gpullama3/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ public record Options(Path modelPath, String prompt, String systemPrompt, String

public Options {
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
require(Float.isNaN(temperature) || 0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(Float.isNaN(topp) || 0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1");
require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode");
// Publish to system properties so TornadoVMMasterPlan and Llama read the right values
Expand Down Expand Up @@ -44,8 +44,8 @@ public static void printUsage(PrintStream out) {
out.println(" --prompt, -p <string> input prompt");
out.println(" --system-prompt, -sp <string> (optional) system prompt (Llama models)");
out.println(" --suffix <string> suffix for fill-in-the-middle request (Codestral)");
out.println(" --temperature, -temp <float> temperature in [0,inf], default 0.1");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.95");
out.println(" --temperature, -temp <float> temperature in [0,inf], default: auto-detected from model family");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1], default: auto-detected from model family");
out.println(" --seed <long> random seed, default System.nanoTime()");
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
Expand All @@ -59,8 +59,8 @@ public static Options getDefaultOptions() {
String prompt = "Tell me a story with Java"; // Hardcoded for testing
String systemPrompt = null;
String suffix = null;
float temperature = 0.1f;
float topp = 0.95f;
float temperature = Float.NaN; // resolved from model family after loading
float topp = Float.NaN; // resolved from model family after loading
Path modelPath = null;
long seed = System.nanoTime();
int maxTokens = DEFAULT_MAX_TOKENS;
Expand All @@ -76,8 +76,8 @@ public static Options parseOptions(String[] args) {
String prompt = "Tell me a story with Java"; // Hardcoded for testing
String systemPrompt = null;
String suffix = null;
float temperature = 0.1f;
float topp = 0.95f;
float temperature = Float.NaN; // resolved from model family after loading
float topp = Float.NaN; // resolved from model family after loading
Path modelPath = null;
long seed = System.nanoTime();
int maxTokens = DEFAULT_MAX_TOKENS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
}

/**
 * Creates the sampler for the given model, resolving sampling parameters that were left unset.
 *
 * <p>{@code Options} uses {@link Float#NaN} as the "not provided" marker for temperature and
 * top-p. Any value that is NaN is replaced by the default reported by the model's chat format;
 * explicitly supplied values are passed through unchanged.
 *
 * @param model   loaded model; supplies the vocabulary size and the chat-format defaults
 * @param options parsed options; supplies temperature, top-p and the random seed
 * @return the sampler chosen by {@code selectSampler} for the resolved parameters
 */
static Sampler createSampler(Model model, Options options) {
    // NaN means "unset": fall back to the chat format's default rather than sampling with NaN.
    float temperature = Float.isNaN(options.temperature())
            ? (float) model.chatFormat().defaultTemperature()
            : options.temperature();
    float topp = Float.isNaN(options.topp())
            ? (float) model.chatFormat().defaultTopP()
            : options.topp();
    return selectSampler(model.configuration().vocabularySize(), temperature, topp, options.seed());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ public Qwen3State(Configuration config, int batchsize) {
this.tempKcur = new FloatArray(nEmbdHead);
}

/**
 * Per-token Q dimension for Qwen3, computed as
 * {@code numberOfHeadsKey() * numberOfHeads()} of the Qwen3 configuration.
 *
 * @param config model configuration; must be a {@code Qwen3Configuration}
 * @return the product of {@code numberOfHeadsKey()} and {@code numberOfHeads()}
 * @throws ClassCastException if {@code config} is not a {@code Qwen3Configuration}
 */
@Override
protected int batchQDim(Configuration config) {
    final Qwen3Configuration qwen3Config = (Qwen3Configuration) config;
    // NOTE(review): numberOfHeadsKey() is presumably the per-head key size — confirm in Qwen3Configuration.
    final int headsKey = qwen3Config.numberOfHeadsKey();
    final int heads = qwen3Config.numberOfHeads();
    return headsKey * heads;
}

/**
 * Per-token K/V dimension for Qwen3, computed as
 * {@code numberOfHeadsValue() * numberOfKeyValueHeads()} of the Qwen3 configuration.
 *
 * @param config model configuration; must be a {@code Qwen3Configuration}
 * @return the product of {@code numberOfHeadsValue()} and {@code numberOfKeyValueHeads()}
 * @throws ClassCastException if {@code config} is not a {@code Qwen3Configuration}
 */
@Override
protected int batchKvDim(Configuration config) {
    final Qwen3Configuration qwen3Config = (Qwen3Configuration) config;
    // NOTE(review): numberOfHeadsValue() is presumably the per-head value size — confirm in Qwen3Configuration.
    final int headsValue = qwen3Config.numberOfHeadsValue();
    final int kvHeads = qwen3Config.numberOfKeyValueHeads();
    return headsValue * kvHeads;
}

@Override
protected StateFields createStateFields(Configuration configuration) {
StateFields fields = new StateFields();
Expand Down
21 changes: 16 additions & 5 deletions src/main/java/org/beehive/gpullama3/inference/state/State.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ public abstract class State {
public final HalfFloatArray embeddingXBatch; // B × dim (FP16 input)
public final FloatArray wrapXBatch; // B × dim (live activations / Q8_0 dequant)
public final HalfFloatArray wrapXbFP16Batch; // B × dim (RMSNorm output, FP16)
public final FloatArray wrapQBatch; // B × dim
public final FloatArray wrapQBatch; // B × qDim (Q projection)
public final FloatArray wrapKBatch; // B × kvDim
public final FloatArray wrapVBatch; // B × kvDim
public final FloatArray wrapXbBatch; // B × dim (attention output)
public final FloatArray wrapXbBatch; // B × qDim (attention output)
public final FloatArray wrapHbBatch; // B × hiddenDim
public final FloatArray attnScaleBatch; // B (per-token RMS scale, attn)
public final FloatArray ffnScaleBatch; // B (per-token RMS scale, FFN)
Expand Down Expand Up @@ -135,14 +135,15 @@ protected State(Configuration config, int batchsize) {

int gpuBatchSize = Integer.getInteger("llama.prefillBatchSize", 1);
if (gpuBatchSize > 1) {
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
int qDim = batchQDim(config);
int kvDim = batchKvDim(config);
this.embeddingXBatch = new HalfFloatArray(gpuBatchSize * config.dim());
this.wrapXBatch = new FloatArray(gpuBatchSize * config.dim());
this.wrapXbFP16Batch = new HalfFloatArray(gpuBatchSize * config.dim());
this.wrapQBatch = new FloatArray(gpuBatchSize * config.dim());
this.wrapQBatch = new FloatArray(gpuBatchSize * qDim);
this.wrapKBatch = new FloatArray(gpuBatchSize * kvDim);
this.wrapVBatch = new FloatArray(gpuBatchSize * kvDim);
this.wrapXbBatch = new FloatArray(gpuBatchSize * config.dim());
this.wrapXbBatch = new FloatArray(gpuBatchSize * qDim);
this.wrapHbBatch = new FloatArray(gpuBatchSize * config.hiddenDim());
this.attnScaleBatch = new FloatArray(gpuBatchSize);
this.ffnScaleBatch = new FloatArray(gpuBatchSize);
Expand All @@ -162,6 +163,16 @@ protected State(Configuration config, int batchsize) {
}
}

/**
 * Returns the per-token Q-projection output dimension used to size the batch buffers.
 *
 * <p>The default is the model dimension {@code config.dim()}. Subclasses override this when
 * their Q dimension differs. This method is invoked from the {@code State} constructor, so an
 * override must derive its result from {@code config} only and not from subclass fields,
 * which are not yet initialised at that point.
 *
 * @param config model configuration
 * @return Q dimension per token
 */
protected int batchQDim(Configuration config) {
    final int modelDim = config.dim();
    return modelDim;
}

/**
 * Returns the per-token K/V dimension used to size the batch buffers.
 *
 * <p>The default is {@code dim * numberOfKeyValueHeads / numberOfHeads}, with the
 * multiplication done before the integer division. Subclasses override this when their K/V
 * dimension differs. This method is invoked from the {@code State} constructor, so an
 * override must derive its result from {@code config} only and not from subclass fields.
 *
 * @param config model configuration
 * @return K/V dimension per token
 */
protected int batchKvDim(Configuration config) {
    // Multiply first, then divide, so integer truncation happens only once.
    final int dimTimesKvHeads = config.dim() * config.numberOfKeyValueHeads();
    return dimTimesKvHeads / config.numberOfHeads();
}

// Abstract method - subclasses implement their specific allocation logic and sizes
protected abstract StateFields createStateFields(Configuration config);

Expand Down
Loading
Loading