From 9ee34a1eb1a4f9086af3cc7e20ec8fda8135ef4e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Wed, 10 Jun 2026 11:16:04 +0300 Subject: [PATCH 01/23] [prf/dec] Replace `LlamaState` and `LlamaConfiguration` with generic `State` and `Configuration` in batch-prefill and batch-decode activations --- .../plan/components/activation/BatchDecodeActivation.java | 8 ++++---- .../components/activation/BatchPrefillActivation.java | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java index 5b63530d..ea9fb1fe 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan.components.activation; -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; @@ -26,14 +26,14 @@ public class BatchDecodeActivation implements ActivationTaskGraph { private final ImmutableTaskGraph itg; private final int dim; - public BatchDecodeActivation(LlamaState state, LlamaConfiguration config, String lastBatchLayerId, boolean isQ8) { + public BatchDecodeActivation(State state, Configuration config, String lastBatchLayerId, boolean isQ8) { this.dim = config.dim(); KernelContext ctx = new KernelContext(); this.itg = buildGraph(ctx, state, lastBatchLayerId, isQ8).snapshot(); } // @formatter:off - private TaskGraph buildGraph(KernelContext ctx, LlamaState state, + private TaskGraph buildGraph(KernelContext ctx, State state, String lastBatchLayerId, boolean isQ8) { TaskGraph tg = new TaskGraph("decodeActivation") .consumeFromDevice(lastBatchLayerId, state.wrapKeyCache, state.wrapValueCache) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java index 17b80cfa..aa1b02d7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan.components.activation; -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; @@ -29,7 +29,7 @@ public class BatchPrefillActivation implements ActivationTaskGraph { private final int batchSize; private final int dim; - public BatchPrefillActivation(LlamaState state, LlamaConfiguration config, int batchSize, boolean isQ8) { + public BatchPrefillActivation(State state, Configuration config, int batchSize, boolean isQ8) { this.isQ8 = isQ8; this.batchSize = batchSize; this.dim = config.dim(); @@ -37,7 +37,7 @@ public BatchPrefillActivation(LlamaState state, LlamaConfiguration config, int b this.itg = buildGraph(ctx, state).snapshot(); } - private TaskGraph buildGraph(KernelContext ctx, LlamaState state) { + private TaskGraph buildGraph(KernelContext ctx, State state) { if (isQ8) { return new TaskGraph("prefillActivation") .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch) From d23d10e6c1605227547d8b6204e3b7e13b1f8120 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Wed, 10 Jun 2026 11:49:54 +0300 Subject: [PATCH 02/23] [prf/dec] Make State's `qDim` and `kvDim` model-agnostic for batch-prefill activations --- .../gpullama3/inference/state/Qwen3State.java | 12 +++++++++++ .../gpullama3/inference/state/State.java | 21 ++++++++++++++----- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java index d70625b9..c84ccf96 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java @@ -35,6 +35,18 @@ public Qwen3State(Configuration config, int batchsize) { this.tempKcur = new FloatArray(nEmbdHead); } + @Override + protected int batchQDim(Configuration config) { + Qwen3Configuration q3 = (Qwen3Configuration) config; + return q3.numberOfHeadsKey() * q3.numberOfHeads(); + } + + @Override + protected int batchKvDim(Configuration config) { + Qwen3Configuration q3 = (Qwen3Configuration) config; + return q3.numberOfHeadsValue() * q3.numberOfKeyValueHeads(); + } + @Override protected StateFields createStateFields(Configuration configuration) { StateFields fields = new StateFields(); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 4e448508..0807b756 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -77,10 +77,10 @@ public abstract class State { public final HalfFloatArray embeddingXBatch; // B × dim (FP16 input) public final FloatArray wrapXBatch; // B × dim (live activations / Q8_0 dequant) public final HalfFloatArray wrapXbFP16Batch; // B × dim (RMSNorm output, FP16) - public final FloatArray wrapQBatch; // B × dim + public final FloatArray wrapQBatch; // B × qDim (Q projection) public final FloatArray wrapKBatch; // B × kvDim public final FloatArray wrapVBatch; // B × kvDim - public final FloatArray wrapXbBatch; // B × dim (attention output) + public final FloatArray wrapXbBatch; // B × qDim (attention output) public final FloatArray wrapHbBatch; // B × hiddenDim public final FloatArray attnScaleBatch; // B (per-token RMS scale, attn) public final FloatArray ffnScaleBatch; // B (per-token RMS scale, FFN) @@ -135,14 +135,15 @@ protected State(Configuration config, int batchsize) { int gpuBatchSize = Integer.getInteger("llama.prefillBatchSize", 1); if (gpuBatchSize > 1) { - int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int qDim = batchQDim(config); + int kvDim = batchKvDim(config); this.embeddingXBatch = new HalfFloatArray(gpuBatchSize * config.dim()); this.wrapXBatch = new FloatArray(gpuBatchSize * config.dim()); this.wrapXbFP16Batch = new HalfFloatArray(gpuBatchSize * config.dim()); - this.wrapQBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapQBatch = new FloatArray(gpuBatchSize * qDim); this.wrapKBatch = new FloatArray(gpuBatchSize * kvDim); this.wrapVBatch = new FloatArray(gpuBatchSize * kvDim); - this.wrapXbBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapXbBatch = new FloatArray(gpuBatchSize * qDim); this.wrapHbBatch = new FloatArray(gpuBatchSize * config.hiddenDim()); this.attnScaleBatch = new FloatArray(gpuBatchSize); this.ffnScaleBatch = new FloatArray(gpuBatchSize); @@ -162,6 +163,16 @@ protected State(Configuration config, int batchsize) { } } + /** Q-projection output dimension per token (model specific: = dim for Llama; differs for Qwen3). */ + protected int batchQDim(Configuration config) { + return config.dim(); + } + + /** KV-cache dimension per token (model specific: = dim*nHeadKv/nHeads for Llama; differs for Qwen3). */ + protected int batchKvDim(Configuration config) { + return (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + } + // Abstract method - subclasses implement their specific allocation logic and sizes protected abstract StateFields createStateFields(Configuration config); From 474255c527c55fbe45c5209dfd069d8847ae6c53 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Wed, 10 Jun 2026 11:53:03 +0300 Subject: [PATCH 03/23] [prf/dec] Add initial impl of prefill-decode and batch-prefill-decode for Qwen3 models and FP16 and Q8_0 quantizations --- .../beehive/gpullama3/model/qwen3/Qwen3.java | 6 +- .../tornadovm/kernels/Qwen3Kernels.java | 320 ++++++++++++++++++ .../layers/type/fp16/Qwen3FP16FFNLayers.java | 34 +- .../fp16/decode/Qwen3FP16FFNLayersDecode.java | 59 ++++ .../Qwen3FP16FFNLayersPrefillDecode.java | 48 +++ .../prefill/Qwen3FP16LayersBatchPrefill.java | 253 ++++++++++++++ .../layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 35 +- .../q8_0/decode/Qwen3Q8_0FFNLayersDecode.java | 59 ++++ .../Qwen3Q8_0FFNLayersPrefillDecode.java | 45 +++ .../prefill/Qwen3Q8_0LayersBatchPrefill.java | 250 ++++++++++++++ .../tornadovm/plan/ForwardPlanFactory.java | 18 +- .../fp16/Qwen3FP16PlanComponents.java | 52 ++- .../q8_0/Qwen3Q8_0PlanComponents.java | 52 ++- 13 files changed, 1207 insertions(+), 24 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersPrefillDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/Qwen3FP16LayersBatchPrefill.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersPrefillDecode.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/Qwen3Q8_0LayersBatchPrefill.java diff --git a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java index d178be7c..5904316f 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java @@ -2,6 +2,8 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithBatchPrefillDecode; +import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.Qwen3State; import org.beehive.gpullama3.inference.state.State; @@ -88,10 +90,10 @@ public List generateTokens(State state, int startPosition, List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { - throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen3"); + return InferenceEngineWithBatchPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } if (WITH_PREFILL_DECODE) { - throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen3"); + return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java index ffca1e4a..26b2f2b0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java @@ -894,5 +894,325 @@ public static void fusedQKRmsNorm( } } } + + // ── Batch prefill kernels ──────────────────────────────────────────────── + + /** + * Batched fused Q/K/V projection for Qwen3 GQA (FP16 weights, FP16 input). + * + *

Like {@code TransformerBatchPrefillKernels.batchedFusedQKVMatmul} but uses + * separate {@code qDim} (Q output rows) and {@code kvDim} (K/V output rows). + * Row layout: [0, qDim) → Q; [qDim, qDim+kvDim) → K; [qDim+kvDim, qDim+2*kvDim) → V. + * Q output stride per batch = qDim; K/V output stride = kvDim.

+ * + * Worker: B*(qDim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmulFP16( + KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + HalfFloatArray wq, + HalfFloatArray wk, + HalfFloatArray wv, + int inputDim, + int qDim, + int kvDim, + int localWorkGroupSize) { + + int groupId = context.globalIdx / localWorkGroupSize; + int localId = context.localIdx; + int totalRows = qDim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * inputDim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < qDim) { + int rowOff = rowIdx * inputDim; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partial += wq.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * qDim + rowIdx, localSum[0]); + + } else if (rowIdx < qDim + kvDim) { + int kRow = rowIdx - qDim; + int rowOff = kRow * inputDim; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partial += wk.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - qDim - kvDim; + int rowOff = vRow * inputDim; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partial += wv.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + /** + * Batched fused Q/K/V projection for Qwen3 GQA (Q8_0 weights, FP32 input). + * + * Worker: B*(qDim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmulQ8_0( + KernelContext context, + FloatArray wrapXbBatch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + ByteArray wq, + ByteArray wk, + ByteArray wv, + int inputDim, + int qDim, + int kvDim, + int localWorkGroupSize) { + + int groupId = context.globalIdx / localWorkGroupSize; + int localId = context.localIdx; + int totalRows = qDim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * inputDim; + + final int blockSize = 32; + final int Q8_0_BLOCK_BYTES = 34; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < qDim) { + int blocksPerRow = (inputDim + blockSize - 1) / blockSize; + int rowBlockOff = rowIdx * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlock = j % blockSize; + int blockByteOff = (rowBlockOff + blockIdx) * Q8_0_BLOCK_BYTES; + HalfFloat sc = wq.getHalfFloat(blockByteOff); + byte q8 = wq.get(blockByteOff + 2 + withinBlock); + partial += ((float) q8 * sc.getFloat32()) * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * qDim + rowIdx, localSum[0]); + + } else if (rowIdx < qDim + kvDim) { + int kRow = rowIdx - qDim; + int blocksPerRow = (inputDim + blockSize - 1) / blockSize; + int rowBlockOff = kRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlock = j % blockSize; + int blockByteOff = (rowBlockOff + blockIdx) * Q8_0_BLOCK_BYTES; + HalfFloat sc = wk.getHalfFloat(blockByteOff); + byte q8 = wk.get(blockByteOff + 2 + withinBlock); + partial += ((float) q8 * sc.getFloat32()) * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - qDim - kvDim; + int blocksPerRow = (inputDim + blockSize - 1) / blockSize; + int rowBlockOff = vRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlock = j % blockSize; + int blockByteOff = (rowBlockOff + blockIdx) * Q8_0_BLOCK_BYTES; + HalfFloat sc = wv.getHalfFloat(blockByteOff); + byte q8 = wv.get(blockByteOff + 2 + withinBlock); + partial += ((float) q8 * sc.getFloat32()) * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + /** + * Batched fused Q/K RMSNorm for Qwen3 GQA. + * + *

Workgroup layout: B*(nHeads+nHeadKv) groups × nEmbdHead local threads. + * Groups [0, B*nHeads) normalize Q; groups [B*nHeads, B*(nHeads+nHeadKv)) normalize K. + * groupIdx = batchIdx*(nHeads+nHeadKv) + headSlot.

+ * + * Worker: B*(nHeads+nHeadKv) workgroups × nEmbdHead threads. + */ + public static void batchedFusedQKRmsNorm( + KernelContext context, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray qWeights, + FloatArray kWeights, + int nHeads, + int nHeadKv, + int nEmbdHead, + int qDim, + int kvDim, + float rmsNormEps) { + + int groupId = context.globalIdx / nEmbdHead; + int localId = context.localIdx; + int localSize = context.localGroupSizeX; + int totalHeadsPerBatch = nHeads + nHeadKv; + + int batchIdx = groupId / totalHeadsPerBatch; + int headSlot = groupId % totalHeadsPerBatch; + + float[] localSum = context.allocateFloatLocalArray(nEmbdHead); + + if (headSlot < nHeads) { + // Q head + int headOffset = batchIdx * qDim + headSlot * nEmbdHead; + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = wrapQBatch.get(headOffset + i); + partialSum += val * val; + } + localSum[localId] = partialSum; + context.localBarrier(); + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) localSum[localId] += localSum[localId + stride]; + context.localBarrier(); + } + float ss = localSum[0] / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + context.localBarrier(); + for (int i = localId; i < nEmbdHead; i += localSize) { + wrapQBatch.set(headOffset + i, qWeights.get(i) * ss * wrapQBatch.get(headOffset + i)); + } + } else { + // K head + int kHeadIdx = headSlot - nHeads; + int headOffset = batchIdx * kvDim + kHeadIdx * nEmbdHead; + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = wrapKBatch.get(headOffset + i); + partialSum += val * val; + } + localSum[localId] = partialSum; + context.localBarrier(); + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) localSum[localId] += localSum[localId + stride]; + context.localBarrier(); + } + float ss = localSum[0] / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + context.localBarrier(); + for (int i = localId; i < nEmbdHead; i += localSize) { + wrapKBatch.set(headOffset + i, kWeights.get(i) * ss * wrapKBatch.get(headOffset + i)); + } + } + } + + /** + * Batched fused RoPE rotation + KV cache write for Qwen3. + * + *

Like {@code TransformerBatchPrefillKernels.batchedRopeWithKVCache} but uses + * Qwen3 RoPE theta (1 000 000) and a separate {@code qDim} for the Q stride.

+ * + *

globalIdx = batchIdx*(qDim/2) + pairIdx. + * K rotation is applied only when pairIdx < kvDim/2.

+ * + * Worker: B*(qDim/2) global threads, localSize tuned like Llama RoPE. + */ + public static void batchedRopeWithKVCacheQwen3( + KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + int kvDim, + int nEmbdHead, + int layerIndex, + int contextLength, + int qDim) { + + int globalIdx = context.globalIdx; + int halfQDim = qDim / 2; + int batchIdx = globalIdx / halfQDim; + int pairIdx = globalIdx % halfQDim; + + int pos = batchStartPosHolder.get(0) + batchIdx; + + // Qwen3 uses split-half RoPE: pair element ic with ic + nEmbdHead/2 within each head. + int halfEmbdHead = nEmbdHead / 2; + int ic = pairIdx % halfEmbdHead; + int headIdx = pairIdx / halfEmbdHead; + + float freq = 1.0f / TornadoMath.pow(1000000.0f, 2.0f * ic / (float) nEmbdHead); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q (split-half pairs within each head) + int qHeadBase = batchIdx * qDim + headIdx * nEmbdHead; + float v0q = wrapQBatch.get(qHeadBase + ic); + float v1q = wrapQBatch.get(qHeadBase + ic + halfEmbdHead); + wrapQBatch.set(qHeadBase + ic, v0q * fcr - v1q * fci); + wrapQBatch.set(qHeadBase + ic + halfEmbdHead, v0q * fci + v1q * fcr); + + // Rotate K and write K,V to cache (only for KV pairs) + if (pairIdx < kvDim / 2) { + int kHeadIdx = pairIdx / halfEmbdHead; + int kHeadBase = batchIdx * kvDim + kHeadIdx * nEmbdHead; + float v0k = wrapKBatch.get(kHeadBase + ic); + float v1k = wrapKBatch.get(kHeadBase + ic + halfEmbdHead); + float rotK0 = v0k * fcr - v1k * fci; + float rotK1 = v0k * fci + v1k * fcr; + wrapKBatch.set(kHeadBase + ic, rotK0); + wrapKBatch.set(kHeadBase + ic + halfEmbdHead, rotK1); + + int cacheOff = layerIndex * contextLength * kvDim + pos * kvDim + kHeadIdx * nEmbdHead; + wrapKeyCache.set(cacheOff + ic, rotK0); + wrapKeyCache.set(cacheOff + ic + halfEmbdHead, rotK1); + wrapValueCache.set(cacheOff + ic, wrapVBatch.get(kHeadBase + ic)); + wrapValueCache.set(cacheOff + ic + halfEmbdHead, wrapVBatch.get(kHeadBase + ic + halfEmbdHead)); + } + } } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 0e8b21b0..b2646351 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -24,7 +24,7 @@ public class Qwen3FP16FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen3-specific state - private final Qwen3State qwen3State; + protected final Qwen3State qwen3State; // Qwen3-specific GQA parameters private final int nHeadKv; private final int nEmbdHeadK; @@ -192,7 +192,12 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var unifiedLayer = new TaskGraph(taskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(state.wrapX); + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX); + } else { + unifiedLayer.consumeFromDevice(state.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS norm weights @@ -356,12 +361,27 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { config.hiddenDim(), // input dim config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice(qwen3State.wrapX); + .persistOnDevice(qwen3State.wrapX, qwen3State.wrapKeyCache, qwen3State.wrapValueCache); return unifiedLayer; } // @formatter:on + /** + * Returns the explicit predecessor graph name for consumeFromDevice. + * + *

The single-token plan receives {@code wrapX} (and relays all persisted buffers, + * including the KV cache) from a named predecessor graph: the activation graph for + * layer 0, the previous layer graph otherwise. The no-arg consume form looks up the + * current graph's name as the source key, which never matches in interpreter + * mode, so the persisted KV-cache buffer is not propagated and gets re-allocated every + * decode token — exhausting the device-memory pool (OOM) on long generations. + * Decode subclasses override this with their own predecessor names.

+ */ + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "activationUpdate" : "layer_" + (layerIndex - 1); + } + /** * Configure data transfers for first and subsequent layers */ @@ -377,8 +397,12 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // qwen3State.wrapAtt, qwen3State.wrapHb ); } else { - // Subsequent layers: Consume data from previous layer - unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, // + // Subsequent layers: consume from the previous layer graph BY NAME. + // The no-arg consumeFromDevice form uses the current graph's own name as the + // source key, which never matches the predecessor in interpreter mode, so the + // persisted KV cache is not propagated and is re-allocated every token (OOM). + String pred = "layer_" + (layerIndex - 1); + unifiedLayer.consumeFromDevice(pred, context, qwen3State.wrapXb, qwen3State.wrapXb2, // qwen3State.wrapQ, qwen3State.wrapK, // qwen3State.wrapV, qwen3State.wrapKeyCache, // qwen3State.wrapValueCache, qwen3State.wrapAtt, // diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersDecode.java new file mode 100644 index 00000000..83c4aefc --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersDecode.java @@ -0,0 +1,59 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode transformer-layer TaskGraphs for the unified batched prefill-decode plan (Qwen3 FP16). + * + *

Layer 0: KV cache is consumed from "decodeActivation" (already allocated by the batch + * prefill phase). Working buffers get FIRST_EXECUTION allocation. Layers 1+: all consumed + * objects use the explicit predecessor name to satisfy TornadoVM interpreter mode.

+ * + *

Qwen3FP16FFNLayers does not use wrapXbFP16 in any task, so it is excluded.

+ */ +public class Qwen3FP16FFNLayersDecode extends Qwen3FP16FFNLayers { + + public Qwen3FP16FFNLayersDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapAtt, qwen3State.wrapHb); + // KV cache already allocated by batch prefill; relay from decode activation graph. + layer.consumeFromDevice("decodeActivation", + qwen3State.wrapKeyCache, qwen3State.wrapValueCache); + } else { + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder, + qwen3State.temp, qwen3State.tempFFN); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersPrefillDecode.java new file mode 100644 index 00000000..bba3d49e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersPrefillDecode.java @@ -0,0 +1,48 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Decode transformer-layer TaskGraphs for the single-token prefill/decode plan (Qwen3 FP16). + * + *

Layer 0 delegates to the base-class {@link Qwen3FP16FFNLayers#configureLayerDataTransfers} + * which allocates wrapKeyCache/wrapValueCache with FIRST_EXECUTION. Layers 1+ consume all + * live buffers from the explicit predecessor graph to satisfy TornadoVM interpreter mode.

+ * + *

Note: Qwen3FP16FFNLayers does not use wrapXbFP16 in any kernel task, so it is + * intentionally excluded from the consume list.

+ */ +public class Qwen3FP16FFNLayersPrefillDecode extends Qwen3FP16FFNLayers { + + public Qwen3FP16FFNLayersPrefillDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder); + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/Qwen3FP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/Qwen3FP16LayersBatchPrefill.java new file mode 100644 index 00000000..85b116d2 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/Qwen3FP16LayersBatchPrefill.java @@ -0,0 +1,253 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Batched-prefill transformer-layer TaskGraphs for the Qwen3 FP16 unified batched prefill-decode plan. + * + *

Mirrors {@link LlamaFP16LayersBatchPrefill} but adapts to Qwen3's GQA layout and + * Qwen3-specific kernels (fused Q/K RMSNorm, RoPE theta = 1 000 000). Avoids any calls to + * {@code Qwen3Configuration.headSize()}, {@code kvDim()}, or {@code kvMul()} which throw.

+ */ +public class Qwen3FP16LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs { + + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final int nHeadKv; + private final int nEmbdHeadK; + private final int nEmbdHeadV; + private final int nEmbdHead; + private final int qDim; + private final int kvDim; + private final int gqa; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public Qwen3FP16LayersBatchPrefill(Qwen3State state, Qwen3TornadoWeights weights, + Qwen3Configuration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.nHeadKv = config.numberOfKeyValueHeads(); + this.nEmbdHeadK = config.numberOfHeadsKey(); + this.nEmbdHeadV = config.numberOfHeadsValue(); + this.nEmbdHead = nEmbdHeadV; + this.qDim = nEmbdHeadK * config.numberOfHeads(); + this.kvDim = nEmbdHeadV * nHeadKv; + this.gqa = config.numberOfHeads() / nHeadKv; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchPrefillLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { + String graphName = "batchPrefillLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph layer = new TaskGraph(graphName); + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + layer.consumeFromDevice("prefillActivation", state.wrapXBatch); + } else { + String pred = "batchPrefillLayer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights: upload once + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + + // ── Attention Block ──────────────────────────────────────────────────── + layer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP16, + context, state.wrapXbFP16Batch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + layer.task("batch_qkv", + Qwen3Kernels::batchedFusedQKVMatmulFP16, + context, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + dim, qDim, kvDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_qk_rmsnorm", + Qwen3Kernels::batchedFusedQKRmsNorm, + context, + state.wrapQBatch, state.wrapKBatch, + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + config.numberOfHeads(), nHeadKv, nEmbdHead, + qDim, kvDim, config.rmsNormEps()); + + layer.task("batch_rope_kv", + Qwen3Kernels::batchedRopeWithKVCacheQwen3, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, nEmbdHead, layerIndex, config.contextLength(), qDim); + + // Reuses batchedFlashAttention: passes qDim as the 'dim' stride parameter. + // Valid because qDim == dim for all standard Qwen3 models (nEmbdHeadK = dim/nHeads). + layer.task("batch_attention", + TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), nEmbdHead, + kvDim, gqa, layerIndex, config.contextLength(), qDim); + + // Output projection: n=qDim (input), d=dim (output) + layer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asHalfFloatArray(), + qDim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + layer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asHalfFloatArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return layer; + } + // @formatter:on + + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + + int qkvRows = qDim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker( + batchSize * (config.numberOfHeads() + nHeadKv) * nEmbdHead, nEmbdHead); + + int ropeGlobal = batchSize * (qDim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + + int optLocal = findOptimalLocalSize(nEmbdHead); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * config.numberOfHeads() * optLocal, optLocal); + + // Wo: B*dim output rows (n=qDim, d=dim) + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchPrefillLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_qk_rmsnorm", qkRmsNormWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 47584da6..5e3279a1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -28,7 +28,7 @@ public class Qwen3Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen3-specific state - private final Qwen3State qwen3State; + protected final Qwen3State qwen3State; // Qwen3-specific GQA parameters private final int nHeadKv; @@ -110,7 +110,12 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var unifiedLayer = new TaskGraph(taskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(qwen3State.wrapX); + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, qwen3State.wrapX); + } else { + unifiedLayer.consumeFromDevice(qwen3State.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS norm weights @@ -275,12 +280,25 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.persistOnDevice(state.wrapX); + unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); return unifiedLayer; } // @formatter:on + /** + * Returns the explicit predecessor graph name for consumeFromDevice. + * + *

The single-token plan relays all persisted buffers (including the KV cache) from a + * named predecessor graph: the activation graph for layer 0, the previous layer graph + * otherwise. The no-arg consume form does not propagate the persisted KV cache in + * interpreter mode, so it is re-allocated every decode token and exhausts the device + * memory pool (OOM) on long generations. Decode subclasses override this.

+ */ + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "activationUpdate" : "layer_" + (layerIndex - 1); + } + /** * Configure data transfers for first and subsequent layers */ @@ -304,8 +322,12 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapAtt, qwen3State.wrapHb); } else { - // Subsequent layers: Consume data from previous layer - unifiedLayer.consumeFromDevice(context, + // Subsequent layers: consume from the previous layer graph BY NAME. The no-arg + // consume form does not propagate the persisted KV cache in interpreter mode, so + // it would be re-allocated every decode token and exhaust the memory pool (OOM). + String pred = "layer_" + (layerIndex - 1); + unifiedLayer.consumeFromDevice(pred, + context, qwen3State.wrapXb, qwen3State.wrapXb2, qwen3State.wrapQ, @@ -317,8 +339,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapHb, qwen3State.positionHolder); - Qwen3State qwen3State = (Qwen3State) state; - unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); // + unifiedLayer.consumeFromDevice(pred, qwen3State.tempQcur, qwen3State.tempKcur); // } return unifiedLayer; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersDecode.java new file mode 100644 index 00000000..bf77c38d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersDecode.java @@ -0,0 +1,59 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode transformer-layer TaskGraphs for the unified batched prefill-decode plan (Qwen3 Q8_0). + * + *

Layer 0: KV cache consumed from "decodeActivation" (already allocated by batch prefill). + * Layers 1+: all consumed objects use explicit predecessor name for interpreter mode.

+ */ +public class Qwen3Q8_0FFNLayersDecode extends Qwen3Q8_0FFNLayers { + + public Qwen3Q8_0FFNLayersDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.tempQcur, qwen3State.tempKcur); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapAtt, qwen3State.wrapHb); + // KV cache already allocated by batch prefill; relay from decode activation graph. + layer.consumeFromDevice("decodeActivation", + qwen3State.wrapKeyCache, qwen3State.wrapValueCache); + } else { + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder, + qwen3State.temp, qwen3State.tempFFN); + layer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersPrefillDecode.java new file mode 100644 index 00000000..cfa9f0b8 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersPrefillDecode.java @@ -0,0 +1,45 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Decode transformer-layer TaskGraphs for the single-token prefill/decode plan (Qwen3 Q8_0). + * + *

Layer 0 delegates to the base-class which allocates wrapKeyCache/wrapValueCache with + * FIRST_EXECUTION. Layers 1+ consume all live buffers from the explicit predecessor graph.

+ */ +public class Qwen3Q8_0FFNLayersPrefillDecode extends Qwen3Q8_0FFNLayers { + + public Qwen3Q8_0FFNLayersPrefillDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder); + layer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/Qwen3Q8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/Qwen3Q8_0LayersBatchPrefill.java new file mode 100644 index 00000000..b3db3f41 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/Qwen3Q8_0LayersBatchPrefill.java @@ -0,0 +1,250 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Batched-prefill transformer-layer TaskGraphs for the Qwen3 Q8_0 unified batched prefill-decode plan. + * + *

Q8_0 path: wrapXbBatch (FP32) holds normalized activations; wrapXbFP16Batch is not used. + * Mirrors {@link Qwen3FP16LayersBatchPrefill} but uses Q8_0 weights (ByteArray) and FP32 + * attention normalization path.

+ */ +public class Qwen3Q8_0LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs { + + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final int nHeadKv; + private final int nEmbdHeadK; + private final int nEmbdHeadV; + private final int nEmbdHead; + private final int qDim; + private final int kvDim; + private final int gqa; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public Qwen3Q8_0LayersBatchPrefill(Qwen3State state, Qwen3TornadoWeights weights, + Qwen3Configuration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.nHeadKv = config.numberOfKeyValueHeads(); + this.nEmbdHeadK = config.numberOfHeadsKey(); + this.nEmbdHeadV = config.numberOfHeadsValue(); + this.nEmbdHead = nEmbdHeadV; + this.qDim = nEmbdHeadK * config.numberOfHeads(); + this.kvDim = nEmbdHeadV * nHeadKv; + this.gqa = config.numberOfHeads() / nHeadKv; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchPrefillLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { + String graphName = "batchPrefillLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph layer = new TaskGraph(graphName); + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + layer.consumeFromDevice("prefillActivation", state.wrapXBatch); + } else { + String pred = "batchPrefillLayer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXBatch, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights (Q8_0 format) + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + + // ── Attention Block ──────────────────────────────────────────────────── + layer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + // FP32 normalize into wrapXbBatch (Q8_0 path: no FP16 quantize step) + layer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP32, + context, state.wrapXbBatch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + layer.task("batch_qkv", + Qwen3Kernels::batchedFusedQKVMatmulQ8_0, + context, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + dim, qDim, kvDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_qk_rmsnorm", + Qwen3Kernels::batchedFusedQKRmsNorm, + context, + state.wrapQBatch, state.wrapKBatch, + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + config.numberOfHeads(), nHeadKv, nEmbdHead, + qDim, kvDim, config.rmsNormEps()); + + layer.task("batch_rope_kv", + Qwen3Kernels::batchedRopeWithKVCacheQwen3, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, nEmbdHead, layerIndex, config.contextLength(), qDim); + + // Reuses batchedFlashAttention; passes qDim as the 'dim' stride (valid: qDim==dim typically). + layer.task("batch_attention", + TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), nEmbdHead, + kvDim, gqa, layerIndex, config.contextLength(), qDim); + + // Output projection (Q8_0): n=qDim, d=dim + layer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asByteArray(), + qDim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + layer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUpQ8, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asByteArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return layer; + } + // @formatter:on + + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + + int qkvRows = qDim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker( + batchSize * (config.numberOfHeads() + nHeadKv) * nEmbdHead, nEmbdHead); + + int ropeGlobal = batchSize * (qDim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + + int optLocal = findOptimalLocalSize(nEmbdHead); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * config.numberOfHeads() * optLocal, optLocal); + + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchPrefillLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_qk_rmsnorm", qkRmsNormWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java index 504eb98f..042dd354 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java @@ -176,15 +176,21 @@ private static ForwardPlan createQwen2Q8_0Plan(ExecutionMode mode, Qwen2State st } private static ForwardPlan createQwen3FP16Plan(ExecutionMode mode, Qwen3State state, Model model) { - if (mode != ExecutionMode.STANDARD) - throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + F16"); - return new SingleTokenForwardPlan(model, new Qwen3FP16PlanComponents(state, model)); + BatchPrefillDecodeForwardPlanComponents components = new Qwen3FP16PlanComponents(state, model); + return switch (mode) { + case STANDARD -> new SingleTokenForwardPlan(model, components); + case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components); + case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE); + }; } private static ForwardPlan createQwen3Q8_0Plan(ExecutionMode mode, Qwen3State state, Model model) { - if (mode != ExecutionMode.STANDARD) - throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + Q8_0"); - return new SingleTokenForwardPlan(model, new Qwen3Q8_0PlanComponents(state, model)); + BatchPrefillDecodeForwardPlanComponents components = new Qwen3Q8_0PlanComponents(state, model); + return switch (mode) { + case STANDARD -> new SingleTokenForwardPlan(model, components); + case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components); + case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE); + }; } private static ForwardPlan createPhi3FP16Plan(ExecutionMode mode, Phi3State state, Model model) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java index 5b2fe286..4cb7a0aa 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -7,14 +7,21 @@ import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; -import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.Qwen3FP16FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.Qwen3FP16FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.Qwen3FP16LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -public class Qwen3FP16PlanComponents implements SingleTokenForwardPlanComponents { +public class Qwen3FP16PlanComponents implements BatchPrefillDecodeForwardPlanComponents { private final Qwen3State state; private final Qwen3TornadoWeights weights; @@ -28,18 +35,59 @@ public Qwen3FP16PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } + // ── Activations ─────────────────────────────────────────────────────────── + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } + @Override + public ActivationTaskGraph prefillDecodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override + public ActivationTaskGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, false); + } + + @Override + public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, false); + } + + // ── Transformer layer TaskGraphs ────────────────────────────────────────── + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); } + @Override + public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { + return new Qwen3FP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override + public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { + return new Qwen3FP16FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override + public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { + return new Qwen3FP16LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } + + @Override + public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { + return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java index 024c0364..6f067d1f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -7,14 +7,21 @@ import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; -import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.Qwen3Q8_0FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.Qwen3Q8_0FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.Qwen3Q8_0LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -public class Qwen3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { +public class Qwen3Q8_0PlanComponents implements BatchPrefillDecodeForwardPlanComponents { private final Qwen3State state; private final Qwen3TornadoWeights weights; @@ -28,18 +35,59 @@ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } + // ── Activations ─────────────────────────────────────────────────────────── + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } + @Override + public ActivationTaskGraph prefillDecodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override + public ActivationTaskGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, true); + } + + @Override + public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, true); + } + + // ── Transformer layer TaskGraphs ────────────────────────────────────────── + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); } + @Override + public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { + return new Qwen3Q8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override + public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { + return new Qwen3Q8_0FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override + public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { + return new Qwen3Q8_0LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } + + @Override + public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { + return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } } From f5d002a74a0be82183142de10fa5ed6fc0f9b9ad Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 11 Jun 2026 16:26:14 +0300 Subject: [PATCH 04/23] [prf/dec][ci] Add prefill-decode variants ci steps for Qwen3 --- .github/workflows/build-and-run.yml | 94 +++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index 2efab197..9a449fd9 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -187,6 +187,53 @@ jobs: configuration: standard metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-standard.json + - name: FP16 - Run Qwen3-4B-f16.gguf - Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-4B-f16.gguf + model: Qwen3-4B + quantization: F16 + configuration: prefill-decode + flags: --with-prefill-decode + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-prefill-decode.json + + - name: FP16 - Run Qwen3-4B-f16.gguf - Batch-Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-4B-f16.gguf + model: Qwen3-4B + quantization: F16 + configuration: batch-prefill-decode + flags: --with-prefill-decode --batch-prefill-size 32 + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-batch-prefill-decode.json + + # PTX-only: CUDA-graph variants + - name: PTX - FP16 - Run Qwen3-4B-f16.gguf - Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-4B-f16.gguf + model: Qwen3-4B + quantization: F16 + configuration: prefill-decode-cuda-graphs + flags: --with-prefill-decode --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-4b-f16-prefill-decode-cuda-graphs.json + + - name: PTX - FP16 - Run Qwen3-4B-f16.gguf - Batch-Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-4B-f16.gguf + model: Qwen3-4B + quantization: F16 + configuration: batch-prefill-decode-cuda-graphs + flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-4b-f16-batch-prefill-decode-cuda-graphs.json + - name: FP16 - Run Mistral-7B-Instruct-v0.3.fp16.gguf uses: ./.github/actions/run-inference with: @@ -358,6 +405,53 @@ jobs: configuration: standard metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-standard.json + - name: Q8 - Run Qwen3-0.6B-Q8_0.gguf - Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: prefill-decode + flags: --with-prefill-decode + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-prefill-decode.json + + - name: Q8 - Run Qwen3-0.6B-Q8_0.gguf - Batch-Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: batch-prefill-decode + flags: --with-prefill-decode --batch-prefill-size 32 + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-batch-prefill-decode.json + + # PTX-only: CUDA-graph variants + - name: PTX - Q8 - Run Qwen3-0.6B-Q8_0.gguf - Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: prefill-decode-cuda-graphs + flags: --with-prefill-decode --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0-6b-q8-prefill-decode-cuda-graphs.json + + - name: PTX - Q8 - Run Qwen3-0.6B-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: batch-prefill-decode-cuda-graphs + flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0-6b-q8-batch-prefill-decode-cuda-graphs.json + - name: Q8 - Run Phi-3-mini-4k-instruct-Q8_0.gguf uses: ./.github/actions/run-inference with: From e0cddbe6e3de3d346dc0c376f6e4c83378ae41cf Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Thu, 11 Jun 2026 17:37:14 +0300 Subject: [PATCH 05/23] [prf/dec][ci] Use Qwen3-0.6B instead of Qwen3-4B in CI workflows --- .github/workflows/build-and-run.yml | 40 ++++++++++++++--------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index 9a449fd9..f60083cc 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -177,62 +177,62 @@ jobs: flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs metrics_file: ${{ runner.temp }}/metrics-ptx-llama-1b-f16-batch-prefill-decode-cuda-graphs.json - - name: FP16 - Run Qwen3-4B-f16.gguf + - name: FP16 - Run Qwen3-0.6B-f16.gguf uses: ./.github/actions/run-inference with: backend: ${{ matrix.backend.name }} - model_file: Qwen3-4B-f16.gguf - model: Qwen3-4B + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B quantization: F16 configuration: standard - metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-standard.json + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-standard.json - - name: FP16 - Run Qwen3-4B-f16.gguf - Prefill-Decode + - name: FP16 - Run Qwen3-0.6B-f16.gguf - Prefill-Decode uses: ./.github/actions/run-inference with: backend: ${{ matrix.backend.name }} - model_file: Qwen3-4B-f16.gguf - model: Qwen3-4B + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B quantization: F16 configuration: prefill-decode flags: --with-prefill-decode - metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-prefill-decode.json + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-prefill-decode.json - - name: FP16 - Run Qwen3-4B-f16.gguf - Batch-Prefill-Decode + - name: FP16 - Run Qwen3-0.6B-f16.gguf - Batch-Prefill-Decode uses: ./.github/actions/run-inference with: backend: ${{ matrix.backend.name }} - model_file: Qwen3-4B-f16.gguf - model: Qwen3-4B + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B quantization: F16 configuration: batch-prefill-decode flags: --with-prefill-decode --batch-prefill-size 32 - metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-batch-prefill-decode.json + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-batch-prefill-decode.json # PTX-only: CUDA-graph variants - - name: PTX - FP16 - Run Qwen3-4B-f16.gguf - Prefill-Decode-CUDA-Graphs + - name: PTX - FP16 - Run Qwen3-0.6B-f16.gguf - Prefill-Decode-CUDA-Graphs if: matrix.backend.name == 'ptx' uses: ./.github/actions/run-inference with: backend: ${{ matrix.backend.name }} - model_file: Qwen3-4B-f16.gguf - model: Qwen3-4B + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B quantization: F16 configuration: prefill-decode-cuda-graphs flags: --with-prefill-decode --cuda-graphs - metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-4b-f16-prefill-decode-cuda-graphs.json + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0.6b-f16-prefill-decode-cuda-graphs.json - - name: PTX - FP16 - Run Qwen3-4B-f16.gguf - Batch-Prefill-Decode-CUDA-Graphs + - name: PTX - FP16 - Run Qwen3-0.6B-f16.gguf - Batch-Prefill-Decode-CUDA-Graphs if: matrix.backend.name == 'ptx' uses: ./.github/actions/run-inference with: backend: ${{ matrix.backend.name }} - model_file: Qwen3-4B-f16.gguf - model: Qwen3-4B + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B quantization: F16 configuration: batch-prefill-decode-cuda-graphs flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs - metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-4b-f16-batch-prefill-decode-cuda-graphs.json + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0.6b-f16-batch-prefill-decode-cuda-graphs.json - name: FP16 - Run Mistral-7B-Instruct-v0.3.fp16.gguf uses: ./.github/actions/run-inference From 63bf168f8afef3997f700ada7fae13bfdf305ed1 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 17 Apr 2026 15:30:36 +0300 Subject: [PATCH 06/23] [tool][WIP] Introduce tool-calling capabilities by GPULlama3.java # Conflicts: # LlamaTornadoCli.java --- llama-tornado | 15 +- llamaTornado | 15 +- .../beehive/gpullama3/ToolCallingDemo.java | 234 ++++++++++++++++++ .../gpullama3/model/format/ChatFormat.java | 53 ++++ .../model/format/LlamaChatFormat.java | 180 ++++++++++++++ .../model/format/Qwen3ChatFormat.java | 124 ++++++++++ .../model/format/ToolCallExtract.java | 11 + 7 files changed, 628 insertions(+), 4 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/ToolCallingDemo.java create mode 100644 src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java diff --git a/llama-tornado b/llama-tornado index 1d6c3d23..61439dad 100755 --- a/llama-tornado +++ b/llama-tornado @@ -191,7 +191,7 @@ class LlamaRunner: [ "-cp", self._find_llama_jar(), - "org.beehive.gpullama3.LlamaApp", + args.main_class, ] ) cmd.extend(module_config) @@ -246,6 +246,9 @@ class LlamaRunner: elif args.instruct: llama_args.append("--instruct") + if args.tool_demo: + llama_args.append("--tool-demo") + return cmd + llama_args def run(self, args: argparse.Namespace) -> int: @@ -527,6 +530,16 @@ def create_parser() -> argparse.ArgumentParser: # Advanced options advanced_group = parser.add_argument_group("Advanced Options") + advanced_group.add_argument( + "--main-class", + default="org.beehive.gpullama3.cli.LlamaTornadoCli", + help="Java main class to run (default: LlamaTornadoCli)", + ) + advanced_group.add_argument( + "--tool-demo", + action="store_true", + help="Run the tool calling demo (requires a LLaMA 3.1 or Qwen3 model)", + ) advanced_group.add_argument( "--opencl-flags", default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only", diff --git a/llamaTornado b/llamaTornado index 068c7946..869ddb7c 100755 --- a/llamaTornado +++ b/llamaTornado @@ -17,7 +17,8 @@ record Config( boolean printBytecodes, boolean threads, boolean printKernel, boolean fullDump, boolean verboseInit, boolean showCommand, boolean executeAfterShow, - String openclFlags, int maxWaitEvents, boolean verbose + String openclFlags, int maxWaitEvents, boolean verbose, + boolean toolDemo ) {} Config parseArgs(String[] args) { @@ -51,6 +52,7 @@ Config parseArgs(String[] args) { String openclFlags = "-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only"; int maxWaitEvents = 32000; boolean verbose = false; + boolean toolDemo = false; for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -86,6 +88,7 @@ Config parseArgs(String[] args) { case "--opencl-flags" -> openclFlags = args[++i]; case "--max-wait-events" -> maxWaitEvents = Integer.parseInt(args[++i]); case "--verbose", "-v" -> verbose = true; + case "--tool-demo" -> toolDemo = true; default -> { System.err.println("Unknown option: " + args[i]); System.exit(1); @@ -111,7 +114,8 @@ Config parseArgs(String[] args) { return new Config(modelPath, prompt, systemPrompt, temperature, topP, seed, maxTokens, stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax, directMemory, debug, profiler, profilerDumpDir, printBytecodes, threads, printKernel, fullDump, - verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose); + verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose, + toolDemo); } String parseAndScale(String memoryValue, int multiplier) { @@ -169,6 +173,7 @@ void printUsage() { --show-command Display the full Java command --execute-after-show Execute after showing command --verbose, -v Verbose output + --tool-demo Run the tool calling demo -help Show this help -version Show version @@ -281,7 +286,10 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String } } - cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), "org.beehive.gpullama3.LlamaApp")); + var mainClass = cfg.toolDemo() + ? "org.beehive.gpullama3.ToolCallingDemo" + : "org.beehive.gpullama3.LlamaApp"; + cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), mainClass)); // LLaMA arguments cmd.addAll(List.of( @@ -298,6 +306,7 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String if (cfg.systemPrompt() != null) cmd.addAll(List.of("-sp", cfg.systemPrompt())); if (cfg.interactive()) cmd.add("--interactive"); else if (cfg.instruct()) cmd.add("--instruct"); + // --tool-demo is handled by main class selection above, not passed to Java return cmd; } diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java b/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java new file mode 100644 index 00000000..1bee90b3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java @@ -0,0 +1,234 @@ +package org.beehive.gpullama3; + +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.format.ToolCallExtract; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.IntConsumer; + +import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; + +/** + * Standalone demo that exercises tool calling end-to-end directly against the + * GPULlama3.java inference engine — no Quarkus or LangChain4J required. + * + * Usage: + * ./llamaTornado --model /path/to/model.gguf --tool-demo + * ./llamaTornado --model /path/to/model.gguf --tool-demo --gpu --opencl + */ +public class ToolCallingDemo { + + private static final String TOOLS_JSON = """ + [ + { + "type": "function", + "function": { + "name": "list_directory", + "description": "List files and directories at the given path using ls", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory path to list. Defaults to current directory if omitted." + } + }, + "required": [] + } + } + } + ]"""; + + private static final String USER_PROMPT = + "Use the list_directory tool to list the files in the current directory."; + + public static void main(String[] args) throws IOException { + Options options = Options.parseOptions(ensurePrompt(args)); + + System.out.println("=== GPULlama3 Tool Calling Demo ==="); + System.out.println("Model : " + options.modelPath()); + System.out.println("GPU : " + options.useTornadovm()); + System.out.println("Prompt: " + USER_PROMPT); + System.out.println(); + + Model model = loadModel(options); + Sampler sampler = Sampler.createSampler(model, options); + ChatFormat chatFormat = model.chatFormat(); + + // ── Build prompt ────────────────────────────────────────────────────── + String toolSuffix = chatFormat.toolSystemPromptSuffix(TOOLS_JSON); + List promptTokens = new ArrayList<>(); + + if (model.shouldAddBeginOfText()) { + promptTokens.add(chatFormat.getBeginOfText()); + } + if (model.shouldAddSystemPrompt()) { + // Keep the system message minimal — just the tool instructions. + // A preamble like "You are a helpful assistant" can cause the model + // to answer from knowledge instead of calling the tool. + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.SYSTEM, toolSuffix.stripLeading()))); + } + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, USER_PROMPT))); + promptTokens.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + // ── First inference turn ────────────────────────────────────────────── + Set stopTokens = chatFormat.getToolAwareStopTokens(); + State state = model.createNewState(); + + System.out.println("--- Raw model output ---"); + List responseTokens = generateTokens( + model, options, state, promptTokens, stopTokens, options.maxTokens(), sampler); + System.out.println("\n--- End raw output ---\n"); + + if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { + responseTokens.removeLast(); + } + String rawResponse = model.tokenizer().decode(responseTokens); + + // ── Detect and execute tool call ────────────────────────────────────── + Optional toolCall = chatFormat.extractToolCall(rawResponse); + if (toolCall.isPresent()) { + ToolCallExtract tc = toolCall.get(); + System.out.println("✓ Tool call detected!"); + System.out.println(" Function : " + tc.name()); + System.out.println(" Arguments: " + tc.argumentsJson()); + + String toolResult = executeTool(tc); + System.out.println(" Result :\n" + toolResult); + + // ── Feed result back and get final answer ───────────────────────── + List continuation = new ArrayList<>(promptTokens); + continuation.addAll(chatFormat.encodeToolCallAssistantTurn(tc)); + continuation.addAll(chatFormat.encodeToolResultTurn(null, tc.name(), toolResult)); + // Ask the model to summarise in plain text — prevents small models from + // looping back into another tool call when they see tool defs in the system prompt. + continuation.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, + "Based on the tool result above, please answer my question in plain text."))); + continuation.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + State state2 = model.createNewState(); + System.out.println("--- Final answer ---"); + generateTokens(model, options, state2, continuation, + chatFormat.getStopTokens(), options.maxTokens(), sampler); + System.out.println(); + + } else { + System.out.println("✗ No tool call detected."); + System.out.println(" The model responded with plain text instead of a tool call."); + System.out.println(" Note: reliable tool calling typically requires a 7B+ model."); + System.out.println("\n Full decoded response:\n" + rawResponse); + } + + LastRunMetrics.printMetrics(); + } + + // ── Tool executor ───────────────────────────────────────────────────────── + + /** Maximum characters of tool output fed back into the prompt. */ + private static final int MAX_TOOL_RESULT_CHARS = 600; + + private static String executeTool(ToolCallExtract tc) { + return switch (tc.name()) { + case "list_directory" -> { + String path = extractStringArg(tc.argumentsJson(), "path", "."); + yield truncate(runProcess("ls", "-la", path)); + } + default -> "Unknown tool: " + tc.name(); + }; + } + + private static String runProcess(String... command) { + try { + var process = new ProcessBuilder(command) + .redirectErrorStream(true) + .start(); + String output = new String(process.getInputStream().readAllBytes()); + process.waitFor(); + return output; + } catch (Exception e) { + return "Error executing command: " + e.getMessage(); + } + } + + private static String truncate(String text) { + if (text.length() <= MAX_TOOL_RESULT_CHARS) return text; + return text.substring(0, MAX_TOOL_RESULT_CHARS) + "\n... (truncated)"; + } + + /** + * Extracts a string value from a flat JSON object by key. + * Falls back to {@code defaultValue} if the key is absent or if the value is a + * nested object (which happens when a small model echoes the schema definition + * instead of supplying an actual argument value). + */ + private static String extractStringArg(String json, String key, String defaultValue) { + String marker = "\"" + key + "\":"; + int idx = json.indexOf(marker); + if (idx == -1) return defaultValue; + int pos = idx + marker.length(); + while (pos < json.length() && Character.isWhitespace(json.charAt(pos))) pos++; + if (pos >= json.length()) return defaultValue; + // Nested object instead of a plain string — model echoed the schema, use default + if (json.charAt(pos) == '{') return defaultValue; + int valStart = json.indexOf('"', pos); + if (valStart == -1) return defaultValue; + int valEnd = json.indexOf('"', valStart + 1); + if (valEnd == -1) return defaultValue; + return json.substring(valStart + 1, valEnd); + } + + // ── Inference helpers ───────────────────────────────────────────────────── + + private static List generateTokens( + Model model, Options options, State state, + List promptTokens, Set stopTokens, + int maxTokens, Sampler sampler) { + + IntConsumer tokenConsumer = token -> { + if (model.tokenizer().shouldDisplayToken(token)) { + System.out.print(model.tokenizer().decode(List.of(token))); + System.out.flush(); + } + }; + + if (options.useTornadovm()) { + TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); + try { + return model.generateTokensGPU( + state, 0, promptTokens, stopTokens, maxTokens, sampler, + false, tokenConsumer, plan); + } finally { + plan.freeTornadoExecutionPlan(); + } + } else { + return model.generateTokens( + state, 0, promptTokens, stopTokens, maxTokens, sampler, + false, tokenConsumer); + } + } + + private static String[] ensurePrompt(String[] args) { + for (String arg : args) { + if (arg.equals("--prompt") || arg.equals("-p")) return args; + } + String[] extended = Arrays.copyOf(args, args.length + 2); + extended[args.length] = "--prompt"; + extended[args.length + 1] = USER_PROMPT; + return extended; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 827ad625..486cdd54 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -8,6 +8,7 @@ import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; import java.util.List; +import java.util.Optional; import java.util.Set; public interface ChatFormat { @@ -36,6 +37,58 @@ default ChatTokens chatTokens() { Set getStopTokens(); + /** + * Returns plain text to append to the system message content when tools are available. + * The returned string is concatenated to the system message before encoding, so the + * normal {@link #encodeMessage} path handles tokenization. + * + * @param toolsJson JSON array of tool definitions, e.g. + * {@code [{"type":"function","function":{...}}]} + */ + default String toolSystemPromptSuffix(String toolsJson) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Re-encodes a prior assistant tool-call turn into the conversation token stream. + * Used when replaying multi-turn history that contains a previous tool call. + * + * @param toolCall the tool call to encode (name + raw arguments JSON) + */ + default List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Encodes a tool execution result message in the model-native format. + * + * @param toolCallId the ID of the originating tool call (may be ignored by some formats) + * @param toolName the name of the tool that was called + * @param result the result content string + */ + default List encodeToolResultTurn(String toolCallId, String toolName, String result) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Detects and extracts a tool call from fully decoded model response text. + * Returns {@link Optional#empty()} when the response is a plain text answer. + * + * @param responseText the fully decoded response from the model + */ + default Optional extractToolCall(String responseText) { + return Optional.empty(); + } + + /** + * Stop tokens to use when tool calling is enabled. + * Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) + * when emitting a tool call instead of a regular response. + */ + default Set getToolAwareStopTokens() { + return getStopTokens(); + } + record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) { } diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index c98a72c9..6e5f1683 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public class LlamaChatFormat implements ChatFormat { @@ -17,6 +18,7 @@ public class LlamaChatFormat implements ChatFormat { protected final int endOfTurn; protected final int endOfText; protected final int endOfMessage; + protected final int pythonTag; protected final Set stopTokens; public LlamaChatFormat(Tokenizer tokenizer) { @@ -28,6 +30,7 @@ public LlamaChatFormat(Tokenizer tokenizer) { this.endOfTurn = specialTokens.get("<|eot_id|>"); this.endOfText = specialTokens.get("<|end_of_text|>"); this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.pythonTag = specialTokens.getOrDefault("<|python_tag|>", -1); // only in 3.1 this.stopTokens = Set.of(endOfText, endOfTurn); } @@ -71,4 +74,181 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listassistant<|end_header_id|>\n<|python_tag|>JSON<|eom_id|>} + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + if (pythonTag != -1) { + tokens.add(pythonTag); + } + String json = "{\"name\":\"" + toolCall.name() + "\",\"parameters\":" + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeAsList(json)); + if (endOfMessage != -1) { + tokens.add(endOfMessage); + } else { + tokens.add(endOfTurn); + } + return tokens; + } + + /** + * Encodes a tool result using the LLaMA "ipython" role. + * Format: {@code <|start_header_id|>ipython<|end_header_id|>\nresult<|eot_id|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(startHeader); + tokens.addAll(tokenizer.encodeAsList("ipython")); + tokens.add(endHeader); + tokens.addAll(tokenizer.encodeAsList("\n")); + tokens.addAll(tokenizer.encodeAsList(result)); + tokens.add(endOfTurn); + return tokens; + } + + /** + * Detects a tool call in the decoded response text. + * + *

Two formats are recognised:

+ *
    + *
  1. LLaMA 3.1 native: {@code <|python_tag|>{"name":…,"parameters":{…}}}
  2. + *
  3. Fallback: raw JSON object (possibly wrapped in markdown code fences), + * as produced by smaller models that follow the system-prompt instructions but skip + * the special token prefix.
  4. + *
+ */ + @Override + public Optional extractToolCall(String responseText) { + // ── 1. Native LLaMA 3.1 format: <|python_tag|>{...} ────────────────── + int idx = responseText.indexOf("<|python_tag|>"); + if (idx != -1) { + String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); + return parseToolCallJson(json); + } + + // ── 2. Fallback: raw JSON, possibly inside markdown code fences ─────── + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + return parseToolCallJson(stripped); + } + + return Optional.empty(); + } + + /** + * Adds {@code <|eom_id|>} to the stop tokens when tools are enabled. + * LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}. + */ + @Override + public Set getToolAwareStopTokens() { + if (endOfMessage != -1) { + return Set.of(endOfText, endOfTurn, endOfMessage); + } + return stopTokens; + } + + /** + * Strips surrounding markdown code fences (``` or ```json / ```python etc.) if present. + */ + private static String stripMarkdownFences(String text) { + if (!text.startsWith("```")) { + return text; + } + int firstNewline = text.indexOf('\n'); + if (firstNewline == -1) { + return text; + } + String body = text.substring(firstNewline + 1); + if (body.endsWith("```")) { + body = body.substring(0, body.length() - 3).stripTrailing(); + } + return body.strip(); + } + + /** + * Parses a tool call JSON object into a {@link ToolCallExtract}. + * + *

Accepts both formats produced by LLaMA variants:

+ *
    + *
  • {@code {"name":"fn","parameters":{…}}} — LLaMA 3.1 native
  • + *
  • {@code {"function":"fn","parameters":{…}}} — produced by some fine-tunes
  • + *
  • {@code {"name":"fn","arguments":{…}}} — alternative key
  • + *
+ * Uses brace-counting to extract nested argument objects correctly. + */ + private static Optional parseToolCallJson(String json) { + // ── extract tool name: try "name" then "function" ───────────────────── + String name = extractStringValue(json, "name"); + if (name == null) { + name = extractStringValue(json, "function"); + } + if (name == null) { + return Optional.empty(); + } + + // ── extract arguments object: try "parameters" then "arguments" ─────── + String argsJson = extractNestedObject(json, "parameters"); + if (argsJson == null) { + argsJson = extractNestedObject(json, "arguments"); + } + if (argsJson == null) { + argsJson = "{}"; // tool call with no arguments + } + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + /** Extracts the string value for {@code "key":""} from a JSON object. */ + private static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int quoteStart = json.indexOf('"', markerIdx + marker.length()); + if (quoteStart == -1) return null; + int quoteEnd = json.indexOf('"', quoteStart + 1); + if (quoteEnd == -1) return null; + return json.substring(quoteStart + 1, quoteEnd); + } + + /** + * Extracts the JSON object value for {@code "key":{…}} using brace-counting, + * so nested objects are handled correctly regardless of what follows. + */ + private static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int braceStart = json.indexOf('{', markerIdx + marker.length()); + if (braceStart == -1) return null; + int depth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '{') depth++; + else if (c == '}') { + depth--; + if (depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced JSON + } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index b6d2e798..c4e31270 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -129,4 +129,128 @@ public Set getStopTokens() { return stopTokens; } + + // ── Tool calling ────────────────────────────────────────────────────────── + + /** + * Qwen3 tool calling system prompt suffix. + * Appended to the system message; instructs the model to wrap tool calls in + * {@code } XML tags. + */ + @Override + public String toolSystemPromptSuffix(String toolsJson) { + return "\n\n# Tools\n\n" + + "You may call one or more functions to assist with the user query.\n\n" + + "You are provided with function signatures within XML tags:\n" + + "\n" + + toolsJson + + "\n\n\n" + + "For each function call, return a json object with function name and arguments " + + "within XML tags:\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + ""; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history. + * Format: {@code <|im_start|>assistant\n\nJSON\n<|im_end|>} + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + String json = "{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Encodes a tool result using the Qwen3 "tool" role. + * Format: {@code <|im_start|>tool\nresult<|im_end|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("tool\n")); + tokens.addAll(tokenizer.encodeOrdinaryAsList(result)); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Detects a tool call enclosed in {@code } tags. + */ + @Override + public Optional extractToolCall(String responseText) { + int start = responseText.indexOf(""); + int end = responseText.lastIndexOf(""); + if (start == -1 || end == -1) { + return Optional.empty(); + } + String json = responseText.substring(start + "".length(), end).strip(); + return parseToolCallJson(json, "arguments"); + } + + /** + * Parses {@code name} and the value of {@code argsKey} out of a tool-call JSON object. + * Uses brace-counting to extract nested argument objects correctly. + * Avoids a JSON-library dependency. + */ + private static Optional parseToolCallJson(String json, String argsKey) { + // extract "name" + String name = extractStringValue(json, "name"); + if (name == null) { + return Optional.empty(); + } + + // extract arguments object using brace-counting + String argsJson = extractNestedObject(json, argsKey); + if (argsJson == null) { + argsJson = "{}"; // tool call with no arguments + } + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + /** Extracts the string value for {@code "key":""} from a JSON object. */ + private static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int quoteStart = json.indexOf('"', markerIdx + marker.length()); + if (quoteStart == -1) return null; + int quoteEnd = json.indexOf('"', quoteStart + 1); + if (quoteEnd == -1) return null; + return json.substring(quoteStart + 1, quoteEnd); + } + + /** + * Extracts the JSON object value for {@code "key":{…}} using brace-counting, + * so nested objects are handled correctly regardless of what follows. + */ + private static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\":"; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int braceStart = json.indexOf('{', markerIdx + marker.length()); + if (braceStart == -1) return null; + int depth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '{') depth++; + else if (c == '}') { + depth--; + if (depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced JSON + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java new file mode 100644 index 00000000..335c2176 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java @@ -0,0 +1,11 @@ +package org.beehive.gpullama3.model.format; + +/** + * Represents a single tool call extracted from a model response. + * Contains the raw strings — JSON parsing of arguments is left to the caller. + * + * @param name the tool/function name to invoke + * @param argumentsJson the arguments as a JSON object string, e.g. {"location":"Boston"} + */ +public record ToolCallExtract(String name, String argumentsJson) { +} From 12f15838417c7a99645908196e9eb8ab429a7638 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Sat, 18 Apr 2026 15:05:11 +0300 Subject: [PATCH 07/23] [tools] Refactor tool-calling architecture: modularize components and extract ToolCallingDemo --- .../beehive/gpullama3/ToolCallingDemo.java | 234 ------------------ .../model/format/LlamaChatFormat.java | 137 ++-------- .../model/format/Qwen3ChatFormat.java | 64 +---- .../model/format/ToolCallParserUtils.java | 143 +++++++++++ .../gpullama3/tools/ToolCallingOptions.java | 31 +++ .../gpullama3/tools/ToolCallingResult.java | 29 +++ .../gpullama3/tools/ToolCallingSession.java | 179 ++++++++++++++ .../gpullama3/tools/ToolDefinition.java | 17 ++ .../beehive/gpullama3/tools/ToolExecutor.java | 13 + .../beehive/gpullama3/tools/ToolRegistry.java | 94 +++++++ .../beehive/gpullama3/tools/ToolResult.java | 24 ++ 11 files changed, 549 insertions(+), 416 deletions(-) delete mode 100644 src/main/java/org/beehive/gpullama3/ToolCallingDemo.java create mode 100644 src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java create mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolResult.java diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java b/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java deleted file mode 100644 index 1bee90b3..00000000 --- a/src/main/java/org/beehive/gpullama3/ToolCallingDemo.java +++ /dev/null @@ -1,234 +0,0 @@ -package org.beehive.gpullama3; - -import org.beehive.gpullama3.auxiliary.LastRunMetrics; -import org.beehive.gpullama3.inference.sampler.Sampler; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.model.format.ToolCallExtract; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.function.IntConsumer; - -import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; - -/** - * Standalone demo that exercises tool calling end-to-end directly against the - * GPULlama3.java inference engine — no Quarkus or LangChain4J required. - * - * Usage: - * ./llamaTornado --model /path/to/model.gguf --tool-demo - * ./llamaTornado --model /path/to/model.gguf --tool-demo --gpu --opencl - */ -public class ToolCallingDemo { - - private static final String TOOLS_JSON = """ - [ - { - "type": "function", - "function": { - "name": "list_directory", - "description": "List files and directories at the given path using ls", - "parameters": { - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Directory path to list. Defaults to current directory if omitted." - } - }, - "required": [] - } - } - } - ]"""; - - private static final String USER_PROMPT = - "Use the list_directory tool to list the files in the current directory."; - - public static void main(String[] args) throws IOException { - Options options = Options.parseOptions(ensurePrompt(args)); - - System.out.println("=== GPULlama3 Tool Calling Demo ==="); - System.out.println("Model : " + options.modelPath()); - System.out.println("GPU : " + options.useTornadovm()); - System.out.println("Prompt: " + USER_PROMPT); - System.out.println(); - - Model model = loadModel(options); - Sampler sampler = Sampler.createSampler(model, options); - ChatFormat chatFormat = model.chatFormat(); - - // ── Build prompt ────────────────────────────────────────────────────── - String toolSuffix = chatFormat.toolSystemPromptSuffix(TOOLS_JSON); - List promptTokens = new ArrayList<>(); - - if (model.shouldAddBeginOfText()) { - promptTokens.add(chatFormat.getBeginOfText()); - } - if (model.shouldAddSystemPrompt()) { - // Keep the system message minimal — just the tool instructions. - // A preamble like "You are a helpful assistant" can cause the model - // to answer from knowledge instead of calling the tool. - promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.SYSTEM, toolSuffix.stripLeading()))); - } - promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, USER_PROMPT))); - promptTokens.addAll(chatFormat.encodeHeader( - new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - - // ── First inference turn ────────────────────────────────────────────── - Set stopTokens = chatFormat.getToolAwareStopTokens(); - State state = model.createNewState(); - - System.out.println("--- Raw model output ---"); - List responseTokens = generateTokens( - model, options, state, promptTokens, stopTokens, options.maxTokens(), sampler); - System.out.println("\n--- End raw output ---\n"); - - if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { - responseTokens.removeLast(); - } - String rawResponse = model.tokenizer().decode(responseTokens); - - // ── Detect and execute tool call ────────────────────────────────────── - Optional toolCall = chatFormat.extractToolCall(rawResponse); - if (toolCall.isPresent()) { - ToolCallExtract tc = toolCall.get(); - System.out.println("✓ Tool call detected!"); - System.out.println(" Function : " + tc.name()); - System.out.println(" Arguments: " + tc.argumentsJson()); - - String toolResult = executeTool(tc); - System.out.println(" Result :\n" + toolResult); - - // ── Feed result back and get final answer ───────────────────────── - List continuation = new ArrayList<>(promptTokens); - continuation.addAll(chatFormat.encodeToolCallAssistantTurn(tc)); - continuation.addAll(chatFormat.encodeToolResultTurn(null, tc.name(), toolResult)); - // Ask the model to summarise in plain text — prevents small models from - // looping back into another tool call when they see tool defs in the system prompt. - continuation.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, - "Based on the tool result above, please answer my question in plain text."))); - continuation.addAll(chatFormat.encodeHeader( - new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - - State state2 = model.createNewState(); - System.out.println("--- Final answer ---"); - generateTokens(model, options, state2, continuation, - chatFormat.getStopTokens(), options.maxTokens(), sampler); - System.out.println(); - - } else { - System.out.println("✗ No tool call detected."); - System.out.println(" The model responded with plain text instead of a tool call."); - System.out.println(" Note: reliable tool calling typically requires a 7B+ model."); - System.out.println("\n Full decoded response:\n" + rawResponse); - } - - LastRunMetrics.printMetrics(); - } - - // ── Tool executor ───────────────────────────────────────────────────────── - - /** Maximum characters of tool output fed back into the prompt. */ - private static final int MAX_TOOL_RESULT_CHARS = 600; - - private static String executeTool(ToolCallExtract tc) { - return switch (tc.name()) { - case "list_directory" -> { - String path = extractStringArg(tc.argumentsJson(), "path", "."); - yield truncate(runProcess("ls", "-la", path)); - } - default -> "Unknown tool: " + tc.name(); - }; - } - - private static String runProcess(String... command) { - try { - var process = new ProcessBuilder(command) - .redirectErrorStream(true) - .start(); - String output = new String(process.getInputStream().readAllBytes()); - process.waitFor(); - return output; - } catch (Exception e) { - return "Error executing command: " + e.getMessage(); - } - } - - private static String truncate(String text) { - if (text.length() <= MAX_TOOL_RESULT_CHARS) return text; - return text.substring(0, MAX_TOOL_RESULT_CHARS) + "\n... (truncated)"; - } - - /** - * Extracts a string value from a flat JSON object by key. - * Falls back to {@code defaultValue} if the key is absent or if the value is a - * nested object (which happens when a small model echoes the schema definition - * instead of supplying an actual argument value). - */ - private static String extractStringArg(String json, String key, String defaultValue) { - String marker = "\"" + key + "\":"; - int idx = json.indexOf(marker); - if (idx == -1) return defaultValue; - int pos = idx + marker.length(); - while (pos < json.length() && Character.isWhitespace(json.charAt(pos))) pos++; - if (pos >= json.length()) return defaultValue; - // Nested object instead of a plain string — model echoed the schema, use default - if (json.charAt(pos) == '{') return defaultValue; - int valStart = json.indexOf('"', pos); - if (valStart == -1) return defaultValue; - int valEnd = json.indexOf('"', valStart + 1); - if (valEnd == -1) return defaultValue; - return json.substring(valStart + 1, valEnd); - } - - // ── Inference helpers ───────────────────────────────────────────────────── - - private static List generateTokens( - Model model, Options options, State state, - List promptTokens, Set stopTokens, - int maxTokens, Sampler sampler) { - - IntConsumer tokenConsumer = token -> { - if (model.tokenizer().shouldDisplayToken(token)) { - System.out.print(model.tokenizer().decode(List.of(token))); - System.out.flush(); - } - }; - - if (options.useTornadovm()) { - TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); - try { - return model.generateTokensGPU( - state, 0, promptTokens, stopTokens, maxTokens, sampler, - false, tokenConsumer, plan); - } finally { - plan.freeTornadoExecutionPlan(); - } - } else { - return model.generateTokens( - state, 0, promptTokens, stopTokens, maxTokens, sampler, - false, tokenConsumer); - } - } - - private static String[] ensurePrompt(String[] args) { - for (String arg : args) { - if (arg.equals("--prompt") || arg.equals("-p")) return args; - } - String[] extended = Arrays.copyOf(args, args.length + 2); - extended[args.length] = "--prompt"; - extended[args.length + 1] = USER_PROMPT; - return extended; - } -} diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 6e5f1683..289b3004 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -83,11 +83,17 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, List XML tags:\n\n" + + "\n" + toolsJson + "\n\n\n" + + "IMPORTANT: the \"name\" field in your tool call MUST be exactly one of the function names " + + "listed inside above — not a path, not a word from the user's message.\n\n" + + "For each function call, return a json object with function name and arguments " + + "within XML tags:\n\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + ""; } /** @@ -97,16 +103,9 @@ public String toolSystemPromptSuffix(String toolsJson) { @Override public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); - if (pythonTag != -1) { - tokens.add(pythonTag); - } - String json = "{\"name\":\"" + toolCall.name() + "\",\"parameters\":" + toolCall.argumentsJson() + "}"; + String json = "\n{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}\n"; tokens.addAll(tokenizer.encodeAsList(json)); - if (endOfMessage != -1) { - tokens.add(endOfMessage); - } else { - tokens.add(endOfTurn); - } + tokens.add(endOfTurn); return tokens; } @@ -128,31 +127,13 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St /** * Detects a tool call in the decoded response text. - * - *

Two formats are recognised:

- *
    - *
  1. LLaMA 3.1 native: {@code <|python_tag|>{"name":…,"parameters":{…}}}
  2. - *
  3. Fallback: raw JSON object (possibly wrapped in markdown code fences), - * as produced by smaller models that follow the system-prompt instructions but skip - * the special token prefix.
  4. - *
+ * Supports LLaMA 3.1 (native {@code <|python_tag|>} + {@code "parameters"} key), + * LLaMA 3.2 ({@code "arguments"} key, tag often absent), and a raw-JSON fallback + * for smaller models. Delegates to {@link ToolCallParserUtils#parseLlamaResponse}. */ @Override public Optional extractToolCall(String responseText) { - // ── 1. Native LLaMA 3.1 format: <|python_tag|>{...} ────────────────── - int idx = responseText.indexOf("<|python_tag|>"); - if (idx != -1) { - String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); - return parseToolCallJson(json); - } - - // ── 2. Fallback: raw JSON, possibly inside markdown code fences ─────── - String stripped = stripMarkdownFences(responseText.strip()); - if (stripped.startsWith("{")) { - return parseToolCallJson(stripped); - } - - return Optional.empty(); + return ToolCallParserUtils.parseLlamaResponse(responseText); } /** @@ -167,88 +148,4 @@ public Set getToolAwareStopTokens() { return stopTokens; } - /** - * Strips surrounding markdown code fences (``` or ```json / ```python etc.) if present. - */ - private static String stripMarkdownFences(String text) { - if (!text.startsWith("```")) { - return text; - } - int firstNewline = text.indexOf('\n'); - if (firstNewline == -1) { - return text; - } - String body = text.substring(firstNewline + 1); - if (body.endsWith("```")) { - body = body.substring(0, body.length() - 3).stripTrailing(); - } - return body.strip(); - } - - /** - * Parses a tool call JSON object into a {@link ToolCallExtract}. - * - *

Accepts both formats produced by LLaMA variants:

- *
    - *
  • {@code {"name":"fn","parameters":{…}}} — LLaMA 3.1 native
  • - *
  • {@code {"function":"fn","parameters":{…}}} — produced by some fine-tunes
  • - *
  • {@code {"name":"fn","arguments":{…}}} — alternative key
  • - *
- * Uses brace-counting to extract nested argument objects correctly. - */ - private static Optional parseToolCallJson(String json) { - // ── extract tool name: try "name" then "function" ───────────────────── - String name = extractStringValue(json, "name"); - if (name == null) { - name = extractStringValue(json, "function"); - } - if (name == null) { - return Optional.empty(); - } - - // ── extract arguments object: try "parameters" then "arguments" ─────── - String argsJson = extractNestedObject(json, "parameters"); - if (argsJson == null) { - argsJson = extractNestedObject(json, "arguments"); - } - if (argsJson == null) { - argsJson = "{}"; // tool call with no arguments - } - - return Optional.of(new ToolCallExtract(name, argsJson)); - } - - /** Extracts the string value for {@code "key":""} from a JSON object. */ - private static String extractStringValue(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int quoteStart = json.indexOf('"', markerIdx + marker.length()); - if (quoteStart == -1) return null; - int quoteEnd = json.indexOf('"', quoteStart + 1); - if (quoteEnd == -1) return null; - return json.substring(quoteStart + 1, quoteEnd); - } - - /** - * Extracts the JSON object value for {@code "key":{…}} using brace-counting, - * so nested objects are handled correctly regardless of what follows. - */ - private static String extractNestedObject(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int braceStart = json.indexOf('{', markerIdx + marker.length()); - if (braceStart == -1) return null; - int depth = 0; - for (int i = braceStart; i < json.length(); i++) { - char c = json.charAt(i); - if (c == '{') depth++; - else if (c == '}') { - depth--; - if (depth == 0) return json.substring(braceStart, i + 1); - } - } - return null; // unbalanced JSON - } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index c4e31270..de671a59 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -187,70 +187,10 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St /** * Detects a tool call enclosed in {@code } tags. + * Delegates to {@link ToolCallParserUtils#parseQwen3Response}. */ @Override public Optional extractToolCall(String responseText) { - int start = responseText.indexOf(""); - int end = responseText.lastIndexOf(""); - if (start == -1 || end == -1) { - return Optional.empty(); - } - String json = responseText.substring(start + "".length(), end).strip(); - return parseToolCallJson(json, "arguments"); - } - - /** - * Parses {@code name} and the value of {@code argsKey} out of a tool-call JSON object. - * Uses brace-counting to extract nested argument objects correctly. - * Avoids a JSON-library dependency. - */ - private static Optional parseToolCallJson(String json, String argsKey) { - // extract "name" - String name = extractStringValue(json, "name"); - if (name == null) { - return Optional.empty(); - } - - // extract arguments object using brace-counting - String argsJson = extractNestedObject(json, argsKey); - if (argsJson == null) { - argsJson = "{}"; // tool call with no arguments - } - - return Optional.of(new ToolCallExtract(name, argsJson)); - } - - /** Extracts the string value for {@code "key":""} from a JSON object. */ - private static String extractStringValue(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int quoteStart = json.indexOf('"', markerIdx + marker.length()); - if (quoteStart == -1) return null; - int quoteEnd = json.indexOf('"', quoteStart + 1); - if (quoteEnd == -1) return null; - return json.substring(quoteStart + 1, quoteEnd); - } - - /** - * Extracts the JSON object value for {@code "key":{…}} using brace-counting, - * so nested objects are handled correctly regardless of what follows. - */ - private static String extractNestedObject(String json, String key) { - String marker = "\"" + key + "\":"; - int markerIdx = json.indexOf(marker); - if (markerIdx == -1) return null; - int braceStart = json.indexOf('{', markerIdx + marker.length()); - if (braceStart == -1) return null; - int depth = 0; - for (int i = braceStart; i < json.length(); i++) { - char c = json.charAt(i); - if (c == '{') depth++; - else if (c == '}') { - depth--; - if (depth == 0) return json.substring(braceStart, i + 1); - } - } - return null; // unbalanced JSON + return ToolCallParserUtils.parseQwen3Response(responseText); } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java new file mode 100644 index 00000000..73a327a4 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -0,0 +1,143 @@ +package org.beehive.gpullama3.model.format; + +import java.util.Optional; + +/** + * Pure-string tool-call extraction for Llama and Qwen3 response formats. + * + * All methods are stateless and do not require any model or tokenizer instance, + * making them directly unit-testable. + */ +public final class ToolCallParserUtils { + + private ToolCallParserUtils() {} + + // ── Llama ───────────────────────────────────────────────────────────────── + + /** + * Extracts a tool call from a LLaMA 3.1 or 3.2 model response. + * + * Recognised formats: + * 1. {@code <|python_tag|>{"name":…,"parameters":{…}}} — LLaMA 3.1 native, also accepted by 3.2 + * 2. Raw JSON with {@code "arguments"} key instead of {@code "parameters"} — LLaMA 3.2 instruction format + * 3. Raw JSON object optionally inside markdown code fences — fallback for models that + * follow system-prompt instructions but omit the special-token prefix + * + * Both {@code "parameters"} and {@code "arguments"} are tried so a single implementation + * handles the 3.1 and 3.2 variants transparently. + */ + public static Optional parseLlamaResponse(String responseText) { + // 1. Native LLaMA 3.1 format: <|python_tag|>{...} + int idx = responseText.indexOf("<|python_tag|>"); + if (idx != -1) { + String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); + return parseLlamaJson(json); + } + + // 2. LLaMA 3.2 format: ... + int tcStart = responseText.indexOf(""); + int tcEnd = responseText.lastIndexOf(""); + if (tcStart != -1 && tcEnd != -1 && tcEnd > tcStart) { + String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); + return parseLlamaJson(json); + } + + // 3. Fallback: raw JSON, possibly inside markdown code fences + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + return parseLlamaJson(stripped); + } + + return Optional.empty(); + } + + /** + * Parses a LLaMA-style tool call JSON object. + * Accepts {@code {"name":…,"parameters":{…}}}, {@code {"function":…,"parameters":{…}}}, + * and {@code {"name":…,"arguments":{…}}}. + */ + private static Optional parseLlamaJson(String json) { + String name = extractStringValue(json, "name"); + if (name == null) { + name = extractStringValue(json, "function"); + } + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "parameters"); + if (argsJson == null) argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // ── Qwen3 ───────────────────────────────────────────────────────────────── + + /** + * Extracts a tool call enclosed in {@code } tags + * as produced by Qwen3 models. + */ + public static Optional parseQwen3Response(String responseText) { + int start = responseText.indexOf(""); + int end = responseText.lastIndexOf(""); + if (start == -1 || end == -1 || end <= start) return Optional.empty(); + + String json = responseText.substring(start + "".length(), end).strip(); + + String name = extractStringValue(json, "name"); + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // ── Shared helpers ──────────────────────────────────────────────────────── + + /** Strips surrounding markdown code fences (```…```) if present. */ + public static String stripMarkdownFences(String text) { + if (!text.startsWith("```")) return text; + int firstNewline = text.indexOf('\n'); + if (firstNewline == -1) return text; + String body = text.substring(firstNewline + 1); + if (body.endsWith("```")) body = body.substring(0, body.length() - 3).stripTrailing(); + return body.strip(); + } + + /** Extracts the string value for {@code "key": ""} from a JSON object. Tolerates whitespace around {@code :}. */ + public static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int quoteStart = json.indexOf('"', colonIdx + 1); + if (quoteStart == -1) return null; + int quoteEnd = json.indexOf('"', quoteStart + 1); + if (quoteEnd == -1) return null; + return json.substring(quoteStart + 1, quoteEnd); + } + + /** + * Extracts the JSON object value for {@code "key": {…}} using brace-counting. + * Handles nested objects and tolerates whitespace around {@code :}. + */ + public static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int braceStart = json.indexOf('{', colonIdx + 1); + if (braceStart == -1) return null; + int depth = 0; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (c == '{') depth++; + else if (c == '}') { + if (--depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java new file mode 100644 index 00000000..50ad5c9d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java @@ -0,0 +1,31 @@ +package org.beehive.gpullama3.tools; + +/** + * Tuning parameters for a {@link ToolCallingSession}. + * + * @param maxTokens max tokens per inference call + * @param maxRoundTrips max tool → result → re-inference cycles (default 1) + * @param maxToolResultChars tool output is truncated to this length before feeding back + * @param verbose print step-by-step output to stdout + * @param useGpu use TornadoVM GPU path for inference + */ +public record ToolCallingOptions( + int maxTokens, + int maxRoundTrips, + int maxToolResultChars, + boolean verbose, + boolean useGpu) { + + public static ToolCallingOptions defaults() { + return new ToolCallingOptions(1024, 1, 2000, true, false); + } + + public static ToolCallingOptions from(org.beehive.gpullama3.Options options) { + return new ToolCallingOptions( + options.maxTokens(), + 1, + 2000, + true, + options.useTornadovm()); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java new file mode 100644 index 00000000..9d1c539f --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java @@ -0,0 +1,29 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.model.format.ToolCallExtract; + +import java.util.List; + +/** + * The outcome of a complete tool-calling session (prompt → [tool round-trips] → answer). + * + * @param finalAnswer the model's final plain-text answer + * @param callsMade tool calls that were extracted and executed (may be empty) + * @param results corresponding tool results (same order as callsMade) + * @param reachedMaxRoundTrips true when the session stopped because maxRoundTrips was hit + */ +public record ToolCallingResult( + String finalAnswer, + List callsMade, + List results, + boolean reachedMaxRoundTrips) { + + public boolean hadToolCalls() { + return !callsMade.isEmpty(); + } + + /** Returns a result representing a plain-text (no-tool) response. */ + public static ToolCallingResult plainText(String answer) { + return new ToolCallingResult(answer, List.of(), List.of(), false); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java new file mode 100644 index 00000000..b56be36c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java @@ -0,0 +1,179 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.format.ToolCallExtract; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.IntConsumer; + +/** + * Framework-agnostic orchestrator for the tool-calling loop: + *
+ *   prompt → first generation → extract tool call? → execute → feed result → final answer
+ * 
+ * + * Supports Llama 3.1, Llama 3.2, and Qwen3 via the {@link ChatFormat} abstraction. + * The session is single-use; create a new instance per request. + */ +public class ToolCallingSession { + + private final Model model; + private final Sampler sampler; + private final ToolRegistry registry; + private final ToolCallingOptions options; + + public ToolCallingSession(Model model, Sampler sampler, ToolRegistry registry, ToolCallingOptions options) { + this.model = model; + this.sampler = sampler; + this.registry = registry; + this.options = options; + } + + /** Run with no custom system prompt (the tool definitions become the system message). */ + public ToolCallingResult run(String userPrompt) { + return run(null, userPrompt); + } + + /** + * Run with an optional system prompt prefix. Tool definitions are appended to it. + * + * @param systemPrompt base system prompt, or {@code null} for tools-only + * @param userPrompt the user's request + */ + public ToolCallingResult run(String systemPrompt, String userPrompt) { + ChatFormat chatFormat = model.chatFormat(); + String toolsJson = registry.toToolsJson(); + String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson); + + String effectiveSystem = systemPrompt == null + ? toolSuffix.stripLeading() + : systemPrompt + toolSuffix; + + // ── Build initial prompt tokens ─────────────────────────────────────── + List promptTokens = new ArrayList<>(); + if (model.shouldAddBeginOfText()) { + promptTokens.add(chatFormat.getBeginOfText()); + } + if (model.shouldAddSystemPrompt()) { + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.SYSTEM, effectiveSystem))); + } + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, userPrompt))); + promptTokens.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + Set toolStopTokens = chatFormat.getToolAwareStopTokens(); + + List callsMade = new ArrayList<>(); + List toolResults = new ArrayList<>(); + + State state = model.createNewState(); + TornadoVMMasterPlan plan = options.useGpu() + ? TornadoVMMasterPlan.initializeTornadoVMPlan(state, model) + : null; + + try { + // ── Tool round-trip loop ────────────────────────────────────────────── + for (int round = 0; round < options.maxRoundTrips(); round++) { + log("\n--- First generation (round %d) ---", round + 1); + List responseTokens = generateTokens(state, plan, promptTokens, toolStopTokens); + + // strip trailing stop token + if (!responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast())) { + responseTokens.removeLast(); + } + String rawResponse = model.tokenizer().decode(responseTokens); + + Optional maybeCall = chatFormat.extractToolCall(rawResponse); + if (maybeCall.isEmpty()) { + log("\n--- No tool call detected; returning plain text response ---"); + return new ToolCallingResult(rawResponse, callsMade, toolResults, false); + } + + ToolCallExtract call = maybeCall.get(); + callsMade.add(call); + log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson()); + + ToolResult result = registry.execute(call); + toolResults.add(result); + log("[Tool result] %s", result.isError() ? "ERROR: " + result.error() : truncate(result.resultText())); + + // ── Build continuation tokens ───────────────────────────────────── + String feedbackContent = result.isError() + ? "Tool '" + call.name() + "' failed: " + result.error() + : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText()); + + promptTokens = new ArrayList<>(promptTokens); + promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call)); + promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent)); + promptTokens.addAll(chatFormat.encodeMessage( + new ChatFormat.Message(ChatFormat.Role.USER, + "Using only the tool result above, answer the user's question in plain text. Do not repeat the raw output."))); + promptTokens.addAll(chatFormat.encodeHeader( + new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + } + + // ── Final answer after all tool round-trips ─────────────────────────── + log("\n--- Final generation ---"); + List finalTokens = generateTokens(state, plan, promptTokens, chatFormat.getStopTokens()); + if (!finalTokens.isEmpty() && chatFormat.getStopTokens().contains(finalTokens.getLast())) { + finalTokens.removeLast(); + } + String finalAnswer = model.tokenizer().decode(finalTokens); + + boolean hitLimit = callsMade.size() >= options.maxRoundTrips() + && chatFormat.extractToolCall(finalAnswer).isPresent(); + + return new ToolCallingResult(finalAnswer, callsMade, toolResults, hitLimit); + + } finally { + if (plan != null) plan.freeTornadoExecutionPlan(); + } + } + + // ── Inference ───────────────────────────────────────────────────────────── + + private List generateTokens(State state, TornadoVMMasterPlan plan, + List prompt, Set stopTokens) { + IntConsumer tokenConsumer = options.verbose() ? this::printToken : null; + + if (options.useGpu()) { + return model.generateTokensGPU( + state, 0, prompt, stopTokens, options.maxTokens(), sampler, + false, tokenConsumer, plan); + } else { + return model.generateTokens( + state, 0, prompt, stopTokens, options.maxTokens(), sampler, + false, tokenConsumer); + } + } + + private void printToken(int token) { + if (model.tokenizer().shouldDisplayToken(token)) { + System.out.print(model.tokenizer().decode(List.of(token))); + System.out.flush(); + } + } + + // ── Helpers ─────────────────────────────────────────────────────────────── + + private String truncate(String text) { + if (text == null) return ""; + if (text.length() <= options.maxToolResultChars()) return text; + return text.substring(0, options.maxToolResultChars()) + "\n... (truncated)"; + } + + private void log(String fmt, Object... args) { + if (options.verbose()) { + System.out.printf((fmt) + "%n", args); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java b/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java new file mode 100644 index 00000000..a07695c5 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java @@ -0,0 +1,17 @@ +package org.beehive.gpullama3.tools; + +/** + * Framework-agnostic description of a tool available to the model. + * + * @param name unique tool name + * @param description human-readable description used in the model's system prompt + * @param parametersJson JSON Schema object for the tool's parameters, e.g. + * {@code {"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}} + */ +public record ToolDefinition(String name, String description, String parametersJson) { + + /** Convenience factory for a tool with no parameters. */ + public static ToolDefinition noArgs(String name, String description) { + return new ToolDefinition(name, description, "{\"type\":\"object\",\"properties\":{}}"); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java b/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java new file mode 100644 index 00000000..b8807486 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java @@ -0,0 +1,13 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.model.format.ToolCallExtract; + +/** + * Executes a single tool call and returns the result. + * Implementations are responsible for parsing {@link ToolCallExtract#argumentsJson()} + * and performing the actual action. + */ +@FunctionalInterface +public interface ToolExecutor { + ToolResult execute(ToolCallExtract call); +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java b/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java new file mode 100644 index 00000000..e31fa67e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java @@ -0,0 +1,94 @@ +package org.beehive.gpullama3.tools; + +import org.beehive.gpullama3.model.format.ToolCallExtract; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Holds available tool definitions and their executors. + * Registration order is preserved for deterministic JSON output. + */ +public class ToolRegistry { + + private final Map entries = new LinkedHashMap<>(); + + private record Entry(ToolDefinition definition, ToolExecutor executor) {} + + public ToolRegistry register(ToolDefinition definition, ToolExecutor executor) { + entries.put(definition.name(), new Entry(definition, executor)); + return this; + } + + public Optional getDefinition(String name) { + Entry e = entries.get(name); + return e == null ? Optional.empty() : Optional.of(e.definition()); + } + + public Optional getExecutor(String name) { + Entry e = entries.get(name); + return e == null ? Optional.empty() : Optional.of(e.executor()); + } + + public List definitions() { + return entries.values().stream().map(Entry::definition).toList(); + } + + public boolean isEmpty() { + return entries.isEmpty(); + } + + /** + * Executes the named tool, returning a failure result for unknown tools or + * executor exceptions. Never throws. + */ + public ToolResult execute(ToolCallExtract call) { + Optional executor = getExecutor(call.name()); + if (executor.isEmpty()) { + return ToolResult.failure(call.name(), "Unknown tool: " + call.name()); + } + try { + return executor.get().execute(call); + } catch (Exception e) { + return ToolResult.failure(call.name(), "Tool execution failed: " + e.getMessage()); + } + } + + /** + * Serialises all registered tools to the flat JSON array expected by + * {@code LlamaChatFormat.toolSystemPromptSuffix()} and + * {@code Qwen3ChatFormat.toolSystemPromptSuffix()}. + * + * Format: {@code [{"name":…,"description":…,"parameters":{…}}]} + */ + public String toToolsJson() { + List defs = definitions(); + if (defs.isEmpty()) return "[]"; + + StringBuilder sb = new StringBuilder("[\n"); + for (int i = 0; i < defs.size(); i++) { + ToolDefinition d = defs.get(i); + sb.append(" {\n"); + sb.append(" \"name\": \"").append(escapeJson(d.name())).append("\",\n"); + sb.append(" \"description\": \"").append(escapeJson(d.description())).append("\",\n"); + sb.append(" \"parameters\": ").append(d.parametersJson()).append("\n"); + sb.append(" }"); + if (i < defs.size() - 1) sb.append(","); + sb.append("\n"); + } + sb.append("]"); + return sb.toString(); + } + + private static String escapeJson(String s) { + if (s == null) return ""; + return s.replace("\\", "\\\\") + .replace("\"", "\\\"") + .replace("\n", "\\n") + .replace("\r", "\\r") + .replace("\t", "\\t"); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolResult.java new file mode 100644 index 00000000..6752b207 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tools/ToolResult.java @@ -0,0 +1,24 @@ +package org.beehive.gpullama3.tools; + +/** + * The result of executing a single tool call. + * + * @param toolName name of the tool that was invoked + * @param resultText the tool's output (may be JSON or plain text) + * @param error non-null when execution failed; contains the error message + */ +public record ToolResult(String toolName, String resultText, String error) { + + /** Returns true when the tool execution failed. */ + public boolean isError() { + return error != null; + } + + public static ToolResult success(String toolName, String resultText) { + return new ToolResult(toolName, resultText, null); + } + + public static ToolResult failure(String toolName, String errorMessage) { + return new ToolResult(toolName, null, errorMessage); + } +} From 37084edfcbb484f79ba3b3894a85dc8d4c6bb53e Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Sat, 18 Apr 2026 15:05:33 +0300 Subject: [PATCH 08/23] [introduce] Add standalone ToolCallingApp for tool-calling functionality in GPULlama3 --- .../org/beehive/gpullama3/ToolCallingApp.java | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/ToolCallingApp.java diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingApp.java b/src/main/java/org/beehive/gpullama3/ToolCallingApp.java new file mode 100644 index 00000000..7729d695 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/ToolCallingApp.java @@ -0,0 +1,102 @@ +package org.beehive.gpullama3; + +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ToolCallExtract; +import org.beehive.gpullama3.model.format.ToolCallParserUtils; +import org.beehive.gpullama3.tools.*; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; + +/** + * Standalone tool-calling entry point for GPULlama3.java. + * + * Uses the same command-line flags as {@link LlamaApp} plus the tool-calling loop + * provided by {@link ToolCallingSession}. A {@code listDirectory} tool is registered + * as a built-in demo; extend {@link #buildRegistry()} to add more tools. + * + *
+ * java @$TORNADOVM_HOME/tornado-argfile \
+ *     --add-modules jdk.incubator.vector --enable-preview \
+ *     -cp gpu-llama3.jar org.beehive.gpullama3.ToolCallingApp \
+ *     --model /path/to/model.gguf \
+ *     --prompt "Show me what is inside /tmp" \
+ *     --use-tornadovm true
+ * 
+ */ +public class ToolCallingApp { + + public static void main(String[] args) throws IOException { + Options options = Options.parseOptions(args); + Model model = loadModel(options); + Sampler sampler = createSampler(model, options); + + ToolRegistry registry = buildRegistry(); + ToolCallingOptions tcOptions = ToolCallingOptions.from(options); + ToolCallingSession session = new ToolCallingSession(model, sampler, registry, tcOptions); + + ToolCallingResult result = session.run(options.systemPrompt(), options.prompt()); + + if (!tcOptions.verbose()) { + // verbose=false means ToolCallingSession did not stream tokens — print the answer now + System.out.println(result.finalAnswer()); + } + } + + // ── Tool registry ───────────────────────────────────────────────────────── + + private static ToolRegistry buildRegistry() { + ToolRegistry registry = new ToolRegistry(); + registry.register(listDirectoryDefinition(), ToolCallingApp::listDirectory); + return registry; + } + + private static ToolDefinition listDirectoryDefinition() { + return new ToolDefinition( + "listDirectory", + "Lists the contents of a directory on the local filesystem. " + + "Returns file names and metadata. " + + "Use this when the user asks what is inside a directory or folder.", + """ + {"type":"object","properties":{"path":{"type":"string",\ + "description":"Absolute path of the directory to list, e.g. /tmp or /home/orion"}},\ + "required":["path"]}"""); + } + + private static ToolResult listDirectory(ToolCallExtract call) { + String path = ToolCallParserUtils.extractStringValue(call.argumentsJson(), "path"); + if (path == null || path.isBlank()) { + return ToolResult.failure("listDirectory", + "Could not extract 'path' from arguments: " + call.argumentsJson()); + } + Path dir = Path.of(path); + if (!dir.isAbsolute()) + return ToolResult.failure("listDirectory", "Path must be absolute. Got: " + path); + if (!Files.exists(dir)) + return ToolResult.failure("listDirectory", "Path does not exist: " + path); + if (!Files.isDirectory(dir)) + return ToolResult.failure("listDirectory", "Not a directory: " + path); + try { + StringBuilder sb = new StringBuilder("Contents of ").append(path).append(":\n"); + try (var stream = Files.list(dir)) { + stream.sorted().forEach(entry -> { + boolean isDir = Files.isDirectory(entry); + long size = 0; + try { if (!isDir) size = Files.size(entry); } catch (IOException ignored) {} + sb.append(isDir ? "dir: " : "file: ") + .append(entry.getFileName()) + .append(isDir ? "" : ", " + size + " bytes") + .append("\n"); + }); + } + return ToolResult.success("listDirectory", sb.toString()); + } catch (IOException e) { + return ToolResult.failure("listDirectory", e.getMessage()); + } + } +} From 6fbe7432033b23cbf049cb567a1a3328e0959da9 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 21 Apr 2026 12:28:05 +0300 Subject: [PATCH 09/23] [tools][wip] Add support for Llama 3.2 tool-call injection: batch tool calls, user message integration, and enhanced response parsing. --- .../gpullama3/model/format/ChatFormat.java | 58 ++++++++++++- .../model/format/LlamaChatFormat.java | 65 ++++++++++---- .../model/format/Qwen3ChatFormat.java | 5 ++ .../model/format/ToolCallParserUtils.java | 84 +++++++++++++++++-- .../gpullama3/tools/ToolCallingSession.java | 73 +++++++++++----- 5 files changed, 236 insertions(+), 49 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 486cdd54..c1ceb267 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -39,16 +39,54 @@ default ChatTokens chatTokens() { /** * Returns plain text to append to the system message content when tools are available. - * The returned string is concatenated to the system message before encoding, so the - * normal {@link #encodeMessage} path handles tokenization. + * Used by formats that inject tool definitions into the system message. * - * @param toolsJson JSON array of tool definitions, e.g. - * {@code [{"type":"function","function":{...}}]} + *

Formats that inject tools into the user message instead should override + * {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and + * {@link #toolFirstUserMessagePrefix(String)} rather than this method. + * + * @param toolsJson JSON array of tool definitions */ default String toolSystemPromptSuffix(String toolsJson) { throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); } + /** + * Returns {@code true} when this format injects tool definitions into the + * first user message instead of the system message. + * + *

When this returns {@code true}, callers should: + *

    + *
  1. Prepend {@link #toolSystemMessagePrefix()} to the system message content.
  2. + *
  3. Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.
  4. + *
+ * When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to + * the system message as before. + */ + default boolean injectsToolsInUserMessage() { + return false; + } + + /** + * Returns text to prepend to the system message content when tools are active + * and {@link #injectsToolsInUserMessage()} is {@code true}. + * Default: empty string (no prefix). + */ + default String toolSystemMessagePrefix() { + return ""; + } + + /** + * Returns the preamble to prepend to the first user message when + * {@link #injectsToolsInUserMessage()} is {@code true}. + * The preamble should include the tool definitions and usage instructions. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolFirstUserMessagePrefix(String toolsJson) { + return ""; + } + /** * Re-encodes a prior assistant tool-call turn into the conversation token stream. * Used when replaying multi-turn history that contains a previous tool call. @@ -80,6 +118,18 @@ default Optional extractToolCall(String responseText) { return Optional.empty(); } + /** + * Extracts ALL tool calls from a response. Models may emit multiple + * {@code } blocks in a single turn (batch tool calls). + * The default delegates to {@link #extractToolCall} for formats that + * do not support batch calls. + * + * @param responseText the fully decoded response from the model + */ + default List extractAllToolCalls(String responseText) { + return extractToolCall(responseText).map(List::of).orElse(List.of()); + } + /** * Stop tokens to use when tool calling is enabled. * Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 289b3004..6f2e6ffe 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -78,32 +78,58 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listfirst user message + * (the GGUF-embedded chat template has {@code tools_in_user_message = true} by default). + * The system message receives only an environment prefix; the tools and usage instructions + * go in the user turn. */ @Override - public String toolSystemPromptSuffix(String toolsJson) { - return "\n\n# Tools\n\n" - + "You may call one or more functions to assist with the user query.\n\n" - + "You are provided with function signatures within XML tags:\n\n" - + "\n" + toolsJson + "\n\n\n" - + "IMPORTANT: the \"name\" field in your tool call MUST be exactly one of the function names " - + "listed inside above — not a path, not a word from the user's message.\n\n" - + "For each function call, return a json object with function name and arguments " - + "within XML tags:\n\n" - + "\n" - + "{\"name\": , \"arguments\": }\n" - + ""; + public boolean injectsToolsInUserMessage() { + return true; } /** - * Re-encodes a prior assistant tool-call turn for multi-turn history. - * Format: {@code <|start_header_id|>assistant<|end_header_id|>\n<|python_tag|>JSON<|eom_id|>} + * System-message prefix that signals tool availability to Llama 3.2. + * Matches the template's {@code "Environment: ipython\n"} line. + */ + @Override + public String toolSystemMessagePrefix() { + return "Environment: ipython\n\n"; + } + + /** + * Prepends tool definitions and usage instructions to the first user message, + * matching the Llama 3.2 GGUF chat template ({@code tools_in_user_message = true}). + * + *

Format mirrors: + *

+     * Given the following functions, please respond with a JSON for a function call
+     * with its proper arguments that best answers the given prompt.
+     *
+     * Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.
+     * Do not use variables.
+     *
+     * {toolsJson}
+     *
+     * 
+ */ + @Override + public String toolFirstUserMessagePrefix(String toolsJson) { + return "Given the following functions, please respond with a JSON for a function call " + + "with its proper arguments that best answers the given prompt.\n\n" + + "Respond in the format {\"name\": function name, \"parameters\": dictionary of " + + "argument name and its value}. Do not use variables.\n\n" + + toolsJson + "\n\n"; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history using the + * Llama 3.2 native JSON format: {@code {"name":"…","parameters":{…}}<|eot_id|>}. */ @Override public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); - String json = "\n{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}\n"; + String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}"; tokens.addAll(tokenizer.encodeAsList(json)); tokens.add(endOfTurn); return tokens; @@ -136,6 +162,11 @@ public Optional extractToolCall(String responseText) { return ToolCallParserUtils.parseLlamaResponse(responseText); } + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } + /** * Adds {@code <|eom_id|>} to the stop tokens when tools are enabled. * LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}. diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index de671a59..35930ff1 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -193,4 +193,9 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St public Optional extractToolCall(String responseText) { return ToolCallParserUtils.parseQwen3Response(responseText); } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java index 73a327a4..146148fc 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -1,5 +1,7 @@ package org.beehive.gpullama3.model.format; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; /** @@ -41,6 +43,11 @@ public static Optional parseLlamaResponse(String responseText) String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); return parseLlamaJson(json); } + // 2b. Unclosed — model stopped (eot_id / eom_id) before writing the closing tag + if (tcStart != -1 && tcEnd == -1) { + String json = responseText.substring(tcStart + "".length()).strip(); + return parseLlamaJson(json); + } // 3. Fallback: raw JSON, possibly inside markdown code fences String stripped = stripMarkdownFences(responseText.strip()); @@ -72,6 +79,54 @@ private static Optional parseLlamaJson(String json) { // ── Qwen3 ───────────────────────────────────────────────────────────────── + /** + * Extracts ALL tool calls from a response that may contain multiple + * {@code } blocks (Llama 3.2 and Qwen3 batch calls). + * + * Falls back to the raw-JSON single-call path if no tags are found. + * Returns an empty list when the response contains no tool calls. + */ + public static List parseAllToolCalls(String responseText) { + List calls = new java.util.ArrayList<>(); + + // <|python_tag|> (Llama 3.1) — single call by definition + int pythonIdx = responseText.indexOf("<|python_tag|>"); + if (pythonIdx != -1) { + parseLlamaJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip()) + .ifPresent(calls::add); + return calls; + } + + // Scan for all blocks + int searchFrom = 0; + while (true) { + int start = responseText.indexOf("", searchFrom); + if (start == -1) break; + int end = responseText.indexOf("", start); + String json; + if (end != -1) { + json = responseText.substring(start + "".length(), end).strip(); + searchFrom = end + "".length(); + } else { + // Unclosed tag — model stopped before writing the closing tag + json = responseText.substring(start + "".length()).strip(); + searchFrom = responseText.length(); + } + parseLlamaJson(json).ifPresent(calls::add); + if (end == -1) break; + } + + // Raw JSON fallback (no tags at all) + if (calls.isEmpty()) { + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + parseLlamaJson(stripped).ifPresent(calls::add); + } + } + + return calls; + } + /** * Extracts a tool call enclosed in {@code } tags * as produced by Qwen3 models. @@ -79,9 +134,11 @@ private static Optional parseLlamaJson(String json) { public static Optional parseQwen3Response(String responseText) { int start = responseText.indexOf(""); int end = responseText.lastIndexOf(""); - if (start == -1 || end == -1 || end <= start) return Optional.empty(); + if (start == -1) return Optional.empty(); - String json = responseText.substring(start + "".length(), end).strip(); + String json = (end != -1 && end > start) + ? responseText.substring(start + "".length(), end).strip() + : responseText.substring(start + "".length()).strip(); String name = extractStringValue(json, "name"); if (name == null) return Optional.empty(); @@ -104,7 +161,11 @@ public static String stripMarkdownFences(String text) { return body.strip(); } - /** Extracts the string value for {@code "key": ""} from a JSON object. Tolerates whitespace around {@code :}. */ + /** + * Extracts the string value for {@code "key": ""} from a JSON object. + * Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"}) + * inside the value, so multi-line code strings with embedded {@code "} are returned intact. + */ public static String extractStringValue(String json, String key) { String marker = "\"" + key + "\""; int markerIdx = json.indexOf(marker); @@ -113,9 +174,20 @@ public static String extractStringValue(String json, String key) { if (colonIdx == -1) return null; int quoteStart = json.indexOf('"', colonIdx + 1); if (quoteStart == -1) return null; - int quoteEnd = json.indexOf('"', quoteStart + 1); - if (quoteEnd == -1) return null; - return json.substring(quoteStart + 1, quoteEnd); + // Scan for the closing quote, honouring backslash escapes + int i = quoteStart + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '\\') { + i += 2; // skip escape sequence (e.g. \", \\, \n) + } else if (c == '"') { + break; + } else { + i++; + } + } + if (i >= json.length()) return null; + return json.substring(quoteStart + 1, i); } /** diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java index b56be36c..565294db 100644 --- a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java +++ b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java @@ -50,23 +50,45 @@ public ToolCallingResult run(String userPrompt) { public ToolCallingResult run(String systemPrompt, String userPrompt) { ChatFormat chatFormat = model.chatFormat(); String toolsJson = registry.toToolsJson(); - String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson); - String effectiveSystem = systemPrompt == null - ? toolSuffix.stripLeading() - : systemPrompt + toolSuffix; + // Build effective system and user messages according to the model's tool injection strategy. + // Formats like Llama 3.2 inject tool definitions into the first user message + // (injectsToolsInUserMessage() == true); others (Qwen3, Mistral) append to the system message. + String effectiveSystem; + String effectiveUser; + if (chatFormat.injectsToolsInUserMessage()) { + String prefix = chatFormat.toolSystemMessagePrefix(); + effectiveSystem = systemPrompt == null + ? prefix.stripLeading() + : prefix + systemPrompt; + effectiveUser = chatFormat.toolFirstUserMessagePrefix(toolsJson) + userPrompt; + } else { + String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson); + effectiveSystem = systemPrompt == null + ? toolSuffix.stripLeading() + : systemPrompt + toolSuffix; + effectiveUser = userPrompt; + } + + log("\n[DEBUG] model: %s", model.getClass().getSimpleName()); + log("[DEBUG] chatFormat: %s", chatFormat.getClass().getSimpleName()); + log("[DEBUG] shouldAddSystemPrompt: %s", model.shouldAddSystemPrompt()); + log("[DEBUG] toolAwareStopTokens: %s", chatFormat.getToolAwareStopTokens()); + log("[DEBUG] ── effective system prompt ──────────────────────────"); + log("%s", effectiveSystem); + log("[DEBUG] ── end system prompt ─────────────────────────────────"); // ── Build initial prompt tokens ─────────────────────────────────────── List promptTokens = new ArrayList<>(); if (model.shouldAddBeginOfText()) { promptTokens.add(chatFormat.getBeginOfText()); } - if (model.shouldAddSystemPrompt()) { + if (model.shouldAddSystemPrompt() && !effectiveSystem.isBlank()) { promptTokens.addAll(chatFormat.encodeMessage( new ChatFormat.Message(ChatFormat.Role.SYSTEM, effectiveSystem))); } promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, userPrompt))); + new ChatFormat.Message(ChatFormat.Role.USER, effectiveUser))); promptTokens.addAll(chatFormat.encodeHeader( new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); @@ -92,31 +114,38 @@ public ToolCallingResult run(String systemPrompt, String userPrompt) { } String rawResponse = model.tokenizer().decode(responseTokens); - Optional maybeCall = chatFormat.extractToolCall(rawResponse); - if (maybeCall.isEmpty()) { + log("[DEBUG] raw response (round %d): >>>%s<<<", round + 1, rawResponse); + log("[DEBUG] response tokens: %d (stop token stripped: %s)", + responseTokens.size(), + !responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast()) ? "yes" : "no — stop token was removed before decoding"); + + List batchCalls = chatFormat.extractAllToolCalls(rawResponse); + log("[DEBUG] extractAllToolCalls found: %d call(s)", batchCalls.size()); + if (batchCalls.isEmpty()) { log("\n--- No tool call detected; returning plain text response ---"); return new ToolCallingResult(rawResponse, callsMade, toolResults, false); } - ToolCallExtract call = maybeCall.get(); - callsMade.add(call); - log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson()); + // ── Execute all tool calls found in this response ───────────────── + promptTokens = new ArrayList<>(promptTokens); + for (ToolCallExtract call : batchCalls) { + callsMade.add(call); + log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson()); - ToolResult result = registry.execute(call); - toolResults.add(result); - log("[Tool result] %s", result.isError() ? "ERROR: " + result.error() : truncate(result.resultText())); + ToolResult result = registry.execute(call); + toolResults.add(result); + log("[Tool result] %s", result.isError() ? "ERROR: " + result.error() : truncate(result.resultText())); - // ── Build continuation tokens ───────────────────────────────────── - String feedbackContent = result.isError() - ? "Tool '" + call.name() + "' failed: " + result.error() - : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText()); + String feedbackContent = result.isError() + ? "Tool '" + call.name() + "' failed: " + result.error() + : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText()); - promptTokens = new ArrayList<>(promptTokens); - promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call)); - promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent)); + promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call)); + promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent)); + } promptTokens.addAll(chatFormat.encodeMessage( new ChatFormat.Message(ChatFormat.Role.USER, - "Using only the tool result above, answer the user's question in plain text. Do not repeat the raw output."))); + "Using the tool results above, call the next required tool or answer the user's original question if all steps are complete."))); promptTokens.addAll(chatFormat.encodeHeader( new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); } From 0881f950d03d4628200bce14859b1ea6bf004000 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 28 Apr 2026 14:11:28 +0300 Subject: [PATCH 10/23] [tools][wip] Add fix for tool-calling --- .../beehive/gpullama3/model/format/LlamaChatFormat.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 6f2e6ffe..421a5513 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -129,9 +129,15 @@ public String toolFirstUserMessagePrefix(String toolsJson) { @Override public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + // Preserve the <|python_tag|> prefix used by LLaMA 3.1/3.2 for tool calls so that + // replayed history looks identical to what the model originally generated. + if (pythonTag != -1) { + tokens.add(pythonTag); + } String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}"; tokens.addAll(tokenizer.encodeAsList(json)); - tokens.add(endOfTurn); + // LLaMA 3.1 ends tool-call turns with <|eom_id|>; fall back to <|eot_id|> for 3.2. + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); return tokens; } From 5033dfe8d40d0917e9cb67692b952b0def84eec1 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 09:18:22 +0300 Subject: [PATCH 11/23] [tools] Remove standalone `ToolCallingApp` and its references as it was used only for testing --- llama-tornado | 15 +-- llamaTornado | 15 +-- .../org/beehive/gpullama3/ToolCallingApp.java | 102 ------------------ 3 files changed, 4 insertions(+), 128 deletions(-) delete mode 100644 src/main/java/org/beehive/gpullama3/ToolCallingApp.java diff --git a/llama-tornado b/llama-tornado index 61439dad..1d6c3d23 100755 --- a/llama-tornado +++ b/llama-tornado @@ -191,7 +191,7 @@ class LlamaRunner: [ "-cp", self._find_llama_jar(), - args.main_class, + "org.beehive.gpullama3.LlamaApp", ] ) cmd.extend(module_config) @@ -246,9 +246,6 @@ class LlamaRunner: elif args.instruct: llama_args.append("--instruct") - if args.tool_demo: - llama_args.append("--tool-demo") - return cmd + llama_args def run(self, args: argparse.Namespace) -> int: @@ -530,16 +527,6 @@ def create_parser() -> argparse.ArgumentParser: # Advanced options advanced_group = parser.add_argument_group("Advanced Options") - advanced_group.add_argument( - "--main-class", - default="org.beehive.gpullama3.cli.LlamaTornadoCli", - help="Java main class to run (default: LlamaTornadoCli)", - ) - advanced_group.add_argument( - "--tool-demo", - action="store_true", - help="Run the tool calling demo (requires a LLaMA 3.1 or Qwen3 model)", - ) advanced_group.add_argument( "--opencl-flags", default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only", diff --git a/llamaTornado b/llamaTornado index 869ddb7c..068c7946 100755 --- a/llamaTornado +++ b/llamaTornado @@ -17,8 +17,7 @@ record Config( boolean printBytecodes, boolean threads, boolean printKernel, boolean fullDump, boolean verboseInit, boolean showCommand, boolean executeAfterShow, - String openclFlags, int maxWaitEvents, boolean verbose, - boolean toolDemo + String openclFlags, int maxWaitEvents, boolean verbose ) {} Config parseArgs(String[] args) { @@ -52,7 +51,6 @@ Config parseArgs(String[] args) { String openclFlags = "-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only"; int maxWaitEvents = 32000; boolean verbose = false; - boolean toolDemo = false; for (int i = 0; i < args.length; i++) { switch (args[i]) { @@ -88,7 +86,6 @@ Config parseArgs(String[] args) { case "--opencl-flags" -> openclFlags = args[++i]; case "--max-wait-events" -> maxWaitEvents = Integer.parseInt(args[++i]); case "--verbose", "-v" -> verbose = true; - case "--tool-demo" -> toolDemo = true; default -> { System.err.println("Unknown option: " + args[i]); System.exit(1); @@ -114,8 +111,7 @@ Config parseArgs(String[] args) { return new Config(modelPath, prompt, systemPrompt, temperature, topP, seed, maxTokens, stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax, directMemory, debug, profiler, profilerDumpDir, printBytecodes, threads, printKernel, fullDump, - verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose, - toolDemo); + verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose); } String parseAndScale(String memoryValue, int multiplier) { @@ -173,7 +169,6 @@ void printUsage() { --show-command Display the full Java command --execute-after-show Execute after showing command --verbose, -v Verbose output - --tool-demo Run the tool calling demo -help Show this help -version Show version @@ -286,10 +281,7 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String } } - var mainClass = cfg.toolDemo() - ? "org.beehive.gpullama3.ToolCallingDemo" - : "org.beehive.gpullama3.LlamaApp"; - cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), mainClass)); + cmd.addAll(List.of("-cp", findLlamaJar(llamaRoot), "org.beehive.gpullama3.LlamaApp")); // LLaMA arguments cmd.addAll(List.of( @@ -306,7 +298,6 @@ List buildCommand(Config cfg, String javaHome, String tornadoSdk, String if (cfg.systemPrompt() != null) cmd.addAll(List.of("-sp", cfg.systemPrompt())); if (cfg.interactive()) cmd.add("--interactive"); else if (cfg.instruct()) cmd.add("--instruct"); - // --tool-demo is handled by main class selection above, not passed to Java return cmd; } diff --git a/src/main/java/org/beehive/gpullama3/ToolCallingApp.java b/src/main/java/org/beehive/gpullama3/ToolCallingApp.java deleted file mode 100644 index 7729d695..00000000 --- a/src/main/java/org/beehive/gpullama3/ToolCallingApp.java +++ /dev/null @@ -1,102 +0,0 @@ -package org.beehive.gpullama3; - -import org.beehive.gpullama3.inference.sampler.Sampler; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.format.ToolCallExtract; -import org.beehive.gpullama3.model.format.ToolCallParserUtils; -import org.beehive.gpullama3.tools.*; - -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; - -import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler; -import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; - -/** - * Standalone tool-calling entry point for GPULlama3.java. - * - * Uses the same command-line flags as {@link LlamaApp} plus the tool-calling loop - * provided by {@link ToolCallingSession}. A {@code listDirectory} tool is registered - * as a built-in demo; extend {@link #buildRegistry()} to add more tools. - * - *
- * java @$TORNADOVM_HOME/tornado-argfile \
- *     --add-modules jdk.incubator.vector --enable-preview \
- *     -cp gpu-llama3.jar org.beehive.gpullama3.ToolCallingApp \
- *     --model /path/to/model.gguf \
- *     --prompt "Show me what is inside /tmp" \
- *     --use-tornadovm true
- * 
- */ -public class ToolCallingApp { - - public static void main(String[] args) throws IOException { - Options options = Options.parseOptions(args); - Model model = loadModel(options); - Sampler sampler = createSampler(model, options); - - ToolRegistry registry = buildRegistry(); - ToolCallingOptions tcOptions = ToolCallingOptions.from(options); - ToolCallingSession session = new ToolCallingSession(model, sampler, registry, tcOptions); - - ToolCallingResult result = session.run(options.systemPrompt(), options.prompt()); - - if (!tcOptions.verbose()) { - // verbose=false means ToolCallingSession did not stream tokens — print the answer now - System.out.println(result.finalAnswer()); - } - } - - // ── Tool registry ───────────────────────────────────────────────────────── - - private static ToolRegistry buildRegistry() { - ToolRegistry registry = new ToolRegistry(); - registry.register(listDirectoryDefinition(), ToolCallingApp::listDirectory); - return registry; - } - - private static ToolDefinition listDirectoryDefinition() { - return new ToolDefinition( - "listDirectory", - "Lists the contents of a directory on the local filesystem. " + - "Returns file names and metadata. " + - "Use this when the user asks what is inside a directory or folder.", - """ - {"type":"object","properties":{"path":{"type":"string",\ - "description":"Absolute path of the directory to list, e.g. /tmp or /home/orion"}},\ - "required":["path"]}"""); - } - - private static ToolResult listDirectory(ToolCallExtract call) { - String path = ToolCallParserUtils.extractStringValue(call.argumentsJson(), "path"); - if (path == null || path.isBlank()) { - return ToolResult.failure("listDirectory", - "Could not extract 'path' from arguments: " + call.argumentsJson()); - } - Path dir = Path.of(path); - if (!dir.isAbsolute()) - return ToolResult.failure("listDirectory", "Path must be absolute. Got: " + path); - if (!Files.exists(dir)) - return ToolResult.failure("listDirectory", "Path does not exist: " + path); - if (!Files.isDirectory(dir)) - return ToolResult.failure("listDirectory", "Not a directory: " + path); - try { - StringBuilder sb = new StringBuilder("Contents of ").append(path).append(":\n"); - try (var stream = Files.list(dir)) { - stream.sorted().forEach(entry -> { - boolean isDir = Files.isDirectory(entry); - long size = 0; - try { if (!isDir) size = Files.size(entry); } catch (IOException ignored) {} - sb.append(isDir ? "dir: " : "file: ") - .append(entry.getFileName()) - .append(isDir ? "" : ", " + size + " bytes") - .append("\n"); - }); - } - return ToolResult.success("listDirectory", sb.toString()); - } catch (IOException e) { - return ToolResult.failure("listDirectory", e.getMessage()); - } - } -} From 5f05176a0bfd213beb3b78944caec86fb8cd44f0 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 12:04:05 +0300 Subject: [PATCH 12/23] [tools] Remove tool-calling classes as redundant after dropping stand-alone test-only tool calling app class --- .../gpullama3/tools/ToolCallingOptions.java | 31 --- .../gpullama3/tools/ToolCallingResult.java | 29 --- .../gpullama3/tools/ToolCallingSession.java | 208 ------------------ .../gpullama3/tools/ToolDefinition.java | 17 -- .../beehive/gpullama3/tools/ToolExecutor.java | 13 -- .../beehive/gpullama3/tools/ToolRegistry.java | 94 -------- .../beehive/gpullama3/tools/ToolResult.java | 24 -- 7 files changed, 416 deletions(-) delete mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java delete mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java delete mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java delete mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java delete mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java delete mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java delete mode 100644 src/main/java/org/beehive/gpullama3/tools/ToolResult.java diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java deleted file mode 100644 index 50ad5c9d..00000000 --- a/src/main/java/org/beehive/gpullama3/tools/ToolCallingOptions.java +++ /dev/null @@ -1,31 +0,0 @@ -package org.beehive.gpullama3.tools; - -/** - * Tuning parameters for a {@link ToolCallingSession}. - * - * @param maxTokens max tokens per inference call - * @param maxRoundTrips max tool → result → re-inference cycles (default 1) - * @param maxToolResultChars tool output is truncated to this length before feeding back - * @param verbose print step-by-step output to stdout - * @param useGpu use TornadoVM GPU path for inference - */ -public record ToolCallingOptions( - int maxTokens, - int maxRoundTrips, - int maxToolResultChars, - boolean verbose, - boolean useGpu) { - - public static ToolCallingOptions defaults() { - return new ToolCallingOptions(1024, 1, 2000, true, false); - } - - public static ToolCallingOptions from(org.beehive.gpullama3.Options options) { - return new ToolCallingOptions( - options.maxTokens(), - 1, - 2000, - true, - options.useTornadovm()); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java deleted file mode 100644 index 9d1c539f..00000000 --- a/src/main/java/org/beehive/gpullama3/tools/ToolCallingResult.java +++ /dev/null @@ -1,29 +0,0 @@ -package org.beehive.gpullama3.tools; - -import org.beehive.gpullama3.model.format.ToolCallExtract; - -import java.util.List; - -/** - * The outcome of a complete tool-calling session (prompt → [tool round-trips] → answer). - * - * @param finalAnswer the model's final plain-text answer - * @param callsMade tool calls that were extracted and executed (may be empty) - * @param results corresponding tool results (same order as callsMade) - * @param reachedMaxRoundTrips true when the session stopped because maxRoundTrips was hit - */ -public record ToolCallingResult( - String finalAnswer, - List callsMade, - List results, - boolean reachedMaxRoundTrips) { - - public boolean hadToolCalls() { - return !callsMade.isEmpty(); - } - - /** Returns a result representing a plain-text (no-tool) response. */ - public static ToolCallingResult plainText(String answer) { - return new ToolCallingResult(answer, List.of(), List.of(), false); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java b/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java deleted file mode 100644 index 565294db..00000000 --- a/src/main/java/org/beehive/gpullama3/tools/ToolCallingSession.java +++ /dev/null @@ -1,208 +0,0 @@ -package org.beehive.gpullama3.tools; - -import org.beehive.gpullama3.inference.sampler.Sampler; -import org.beehive.gpullama3.inference.state.State; -import org.beehive.gpullama3.model.Model; -import org.beehive.gpullama3.model.format.ChatFormat; -import org.beehive.gpullama3.model.format.ToolCallExtract; -import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.function.IntConsumer; - -/** - * Framework-agnostic orchestrator for the tool-calling loop: - *
- *   prompt → first generation → extract tool call? → execute → feed result → final answer
- * 
- * - * Supports Llama 3.1, Llama 3.2, and Qwen3 via the {@link ChatFormat} abstraction. - * The session is single-use; create a new instance per request. - */ -public class ToolCallingSession { - - private final Model model; - private final Sampler sampler; - private final ToolRegistry registry; - private final ToolCallingOptions options; - - public ToolCallingSession(Model model, Sampler sampler, ToolRegistry registry, ToolCallingOptions options) { - this.model = model; - this.sampler = sampler; - this.registry = registry; - this.options = options; - } - - /** Run with no custom system prompt (the tool definitions become the system message). */ - public ToolCallingResult run(String userPrompt) { - return run(null, userPrompt); - } - - /** - * Run with an optional system prompt prefix. Tool definitions are appended to it. - * - * @param systemPrompt base system prompt, or {@code null} for tools-only - * @param userPrompt the user's request - */ - public ToolCallingResult run(String systemPrompt, String userPrompt) { - ChatFormat chatFormat = model.chatFormat(); - String toolsJson = registry.toToolsJson(); - - // Build effective system and user messages according to the model's tool injection strategy. - // Formats like Llama 3.2 inject tool definitions into the first user message - // (injectsToolsInUserMessage() == true); others (Qwen3, Mistral) append to the system message. - String effectiveSystem; - String effectiveUser; - if (chatFormat.injectsToolsInUserMessage()) { - String prefix = chatFormat.toolSystemMessagePrefix(); - effectiveSystem = systemPrompt == null - ? prefix.stripLeading() - : prefix + systemPrompt; - effectiveUser = chatFormat.toolFirstUserMessagePrefix(toolsJson) + userPrompt; - } else { - String toolSuffix = chatFormat.toolSystemPromptSuffix(toolsJson); - effectiveSystem = systemPrompt == null - ? toolSuffix.stripLeading() - : systemPrompt + toolSuffix; - effectiveUser = userPrompt; - } - - log("\n[DEBUG] model: %s", model.getClass().getSimpleName()); - log("[DEBUG] chatFormat: %s", chatFormat.getClass().getSimpleName()); - log("[DEBUG] shouldAddSystemPrompt: %s", model.shouldAddSystemPrompt()); - log("[DEBUG] toolAwareStopTokens: %s", chatFormat.getToolAwareStopTokens()); - log("[DEBUG] ── effective system prompt ──────────────────────────"); - log("%s", effectiveSystem); - log("[DEBUG] ── end system prompt ─────────────────────────────────"); - - // ── Build initial prompt tokens ─────────────────────────────────────── - List promptTokens = new ArrayList<>(); - if (model.shouldAddBeginOfText()) { - promptTokens.add(chatFormat.getBeginOfText()); - } - if (model.shouldAddSystemPrompt() && !effectiveSystem.isBlank()) { - promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.SYSTEM, effectiveSystem))); - } - promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, effectiveUser))); - promptTokens.addAll(chatFormat.encodeHeader( - new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - - Set toolStopTokens = chatFormat.getToolAwareStopTokens(); - - List callsMade = new ArrayList<>(); - List toolResults = new ArrayList<>(); - - State state = model.createNewState(); - TornadoVMMasterPlan plan = options.useGpu() - ? TornadoVMMasterPlan.initializeTornadoVMPlan(state, model) - : null; - - try { - // ── Tool round-trip loop ────────────────────────────────────────────── - for (int round = 0; round < options.maxRoundTrips(); round++) { - log("\n--- First generation (round %d) ---", round + 1); - List responseTokens = generateTokens(state, plan, promptTokens, toolStopTokens); - - // strip trailing stop token - if (!responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast())) { - responseTokens.removeLast(); - } - String rawResponse = model.tokenizer().decode(responseTokens); - - log("[DEBUG] raw response (round %d): >>>%s<<<", round + 1, rawResponse); - log("[DEBUG] response tokens: %d (stop token stripped: %s)", - responseTokens.size(), - !responseTokens.isEmpty() && toolStopTokens.contains(responseTokens.getLast()) ? "yes" : "no — stop token was removed before decoding"); - - List batchCalls = chatFormat.extractAllToolCalls(rawResponse); - log("[DEBUG] extractAllToolCalls found: %d call(s)", batchCalls.size()); - if (batchCalls.isEmpty()) { - log("\n--- No tool call detected; returning plain text response ---"); - return new ToolCallingResult(rawResponse, callsMade, toolResults, false); - } - - // ── Execute all tool calls found in this response ───────────────── - promptTokens = new ArrayList<>(promptTokens); - for (ToolCallExtract call : batchCalls) { - callsMade.add(call); - log("\n[Tool call] %s(%s)", call.name(), call.argumentsJson()); - - ToolResult result = registry.execute(call); - toolResults.add(result); - log("[Tool result] %s", result.isError() ? "ERROR: " + result.error() : truncate(result.resultText())); - - String feedbackContent = result.isError() - ? "Tool '" + call.name() + "' failed: " + result.error() - : "Tool '" + call.name() + "' returned:\n" + truncate(result.resultText()); - - promptTokens.addAll(chatFormat.encodeToolCallAssistantTurn(call)); - promptTokens.addAll(chatFormat.encodeToolResultTurn(null, call.name(), feedbackContent)); - } - promptTokens.addAll(chatFormat.encodeMessage( - new ChatFormat.Message(ChatFormat.Role.USER, - "Using the tool results above, call the next required tool or answer the user's original question if all steps are complete."))); - promptTokens.addAll(chatFormat.encodeHeader( - new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); - } - - // ── Final answer after all tool round-trips ─────────────────────────── - log("\n--- Final generation ---"); - List finalTokens = generateTokens(state, plan, promptTokens, chatFormat.getStopTokens()); - if (!finalTokens.isEmpty() && chatFormat.getStopTokens().contains(finalTokens.getLast())) { - finalTokens.removeLast(); - } - String finalAnswer = model.tokenizer().decode(finalTokens); - - boolean hitLimit = callsMade.size() >= options.maxRoundTrips() - && chatFormat.extractToolCall(finalAnswer).isPresent(); - - return new ToolCallingResult(finalAnswer, callsMade, toolResults, hitLimit); - - } finally { - if (plan != null) plan.freeTornadoExecutionPlan(); - } - } - - // ── Inference ───────────────────────────────────────────────────────────── - - private List generateTokens(State state, TornadoVMMasterPlan plan, - List prompt, Set stopTokens) { - IntConsumer tokenConsumer = options.verbose() ? this::printToken : null; - - if (options.useGpu()) { - return model.generateTokensGPU( - state, 0, prompt, stopTokens, options.maxTokens(), sampler, - false, tokenConsumer, plan); - } else { - return model.generateTokens( - state, 0, prompt, stopTokens, options.maxTokens(), sampler, - false, tokenConsumer); - } - } - - private void printToken(int token) { - if (model.tokenizer().shouldDisplayToken(token)) { - System.out.print(model.tokenizer().decode(List.of(token))); - System.out.flush(); - } - } - - // ── Helpers ─────────────────────────────────────────────────────────────── - - private String truncate(String text) { - if (text == null) return ""; - if (text.length() <= options.maxToolResultChars()) return text; - return text.substring(0, options.maxToolResultChars()) + "\n... (truncated)"; - } - - private void log(String fmt, Object... args) { - if (options.verbose()) { - System.out.printf((fmt) + "%n", args); - } - } -} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java b/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java deleted file mode 100644 index a07695c5..00000000 --- a/src/main/java/org/beehive/gpullama3/tools/ToolDefinition.java +++ /dev/null @@ -1,17 +0,0 @@ -package org.beehive.gpullama3.tools; - -/** - * Framework-agnostic description of a tool available to the model. - * - * @param name unique tool name - * @param description human-readable description used in the model's system prompt - * @param parametersJson JSON Schema object for the tool's parameters, e.g. - * {@code {"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}} - */ -public record ToolDefinition(String name, String description, String parametersJson) { - - /** Convenience factory for a tool with no parameters. */ - public static ToolDefinition noArgs(String name, String description) { - return new ToolDefinition(name, description, "{\"type\":\"object\",\"properties\":{}}"); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java b/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java deleted file mode 100644 index b8807486..00000000 --- a/src/main/java/org/beehive/gpullama3/tools/ToolExecutor.java +++ /dev/null @@ -1,13 +0,0 @@ -package org.beehive.gpullama3.tools; - -import org.beehive.gpullama3.model.format.ToolCallExtract; - -/** - * Executes a single tool call and returns the result. - * Implementations are responsible for parsing {@link ToolCallExtract#argumentsJson()} - * and performing the actual action. - */ -@FunctionalInterface -public interface ToolExecutor { - ToolResult execute(ToolCallExtract call); -} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java b/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java deleted file mode 100644 index e31fa67e..00000000 --- a/src/main/java/org/beehive/gpullama3/tools/ToolRegistry.java +++ /dev/null @@ -1,94 +0,0 @@ -package org.beehive.gpullama3.tools; - -import org.beehive.gpullama3.model.format.ToolCallExtract; - -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -/** - * Holds available tool definitions and their executors. - * Registration order is preserved for deterministic JSON output. - */ -public class ToolRegistry { - - private final Map entries = new LinkedHashMap<>(); - - private record Entry(ToolDefinition definition, ToolExecutor executor) {} - - public ToolRegistry register(ToolDefinition definition, ToolExecutor executor) { - entries.put(definition.name(), new Entry(definition, executor)); - return this; - } - - public Optional getDefinition(String name) { - Entry e = entries.get(name); - return e == null ? Optional.empty() : Optional.of(e.definition()); - } - - public Optional getExecutor(String name) { - Entry e = entries.get(name); - return e == null ? Optional.empty() : Optional.of(e.executor()); - } - - public List definitions() { - return entries.values().stream().map(Entry::definition).toList(); - } - - public boolean isEmpty() { - return entries.isEmpty(); - } - - /** - * Executes the named tool, returning a failure result for unknown tools or - * executor exceptions. Never throws. - */ - public ToolResult execute(ToolCallExtract call) { - Optional executor = getExecutor(call.name()); - if (executor.isEmpty()) { - return ToolResult.failure(call.name(), "Unknown tool: " + call.name()); - } - try { - return executor.get().execute(call); - } catch (Exception e) { - return ToolResult.failure(call.name(), "Tool execution failed: " + e.getMessage()); - } - } - - /** - * Serialises all registered tools to the flat JSON array expected by - * {@code LlamaChatFormat.toolSystemPromptSuffix()} and - * {@code Qwen3ChatFormat.toolSystemPromptSuffix()}. - * - * Format: {@code [{"name":…,"description":…,"parameters":{…}}]} - */ - public String toToolsJson() { - List defs = definitions(); - if (defs.isEmpty()) return "[]"; - - StringBuilder sb = new StringBuilder("[\n"); - for (int i = 0; i < defs.size(); i++) { - ToolDefinition d = defs.get(i); - sb.append(" {\n"); - sb.append(" \"name\": \"").append(escapeJson(d.name())).append("\",\n"); - sb.append(" \"description\": \"").append(escapeJson(d.description())).append("\",\n"); - sb.append(" \"parameters\": ").append(d.parametersJson()).append("\n"); - sb.append(" }"); - if (i < defs.size() - 1) sb.append(","); - sb.append("\n"); - } - sb.append("]"); - return sb.toString(); - } - - private static String escapeJson(String s) { - if (s == null) return ""; - return s.replace("\\", "\\\\") - .replace("\"", "\\\"") - .replace("\n", "\\n") - .replace("\r", "\\r") - .replace("\t", "\\t"); - } -} diff --git a/src/main/java/org/beehive/gpullama3/tools/ToolResult.java b/src/main/java/org/beehive/gpullama3/tools/ToolResult.java deleted file mode 100644 index 6752b207..00000000 --- a/src/main/java/org/beehive/gpullama3/tools/ToolResult.java +++ /dev/null @@ -1,24 +0,0 @@ -package org.beehive.gpullama3.tools; - -/** - * The result of executing a single tool call. - * - * @param toolName name of the tool that was invoked - * @param resultText the tool's output (may be JSON or plain text) - * @param error non-null when execution failed; contains the error message - */ -public record ToolResult(String toolName, String resultText, String error) { - - /** Returns true when the tool execution failed. */ - public boolean isError() { - return error != null; - } - - public static ToolResult success(String toolName, String resultText) { - return new ToolResult(toolName, resultText, null); - } - - public static ToolResult failure(String toolName, String errorMessage) { - return new ToolResult(toolName, null, errorMessage); - } -} From 0aa5a1abc05fcbabacd0f4856afdc92a09f72b26 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 12:49:15 +0300 Subject: [PATCH 13/23] [tools] Add `supportsToolCalling` implementation and enhance tool call parsing --- .../beehive/gpullama3/model/format/ChatFormat.java | 10 ++++++++++ .../gpullama3/model/format/LlamaChatFormat.java | 5 +++++ .../gpullama3/model/format/Qwen3ChatFormat.java | 9 ++++++--- .../gpullama3/model/format/ToolCallParserUtils.java | 11 ++++++++--- 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index c1ceb267..c8ffbbf2 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -37,6 +37,16 @@ default ChatTokens chatTokens() { Set getStopTokens(); + /** + * Returns {@code true} when this chat format supports tool calling. + * Formats that implement tool-calling methods must override this to return {@code true}. + * Callers should check this before passing tool specifications to avoid hitting the + * default {@link UnsupportedOperationException} deep inside a format method. + */ + default boolean supportsToolCalling() { + return false; + } + /** * Returns plain text to append to the system message content when tools are available. * Used by formats that inject tool definitions into the system message. diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 421a5513..78461060 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -77,6 +77,11 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listfirst user message * (the GGUF-embedded chat template has {@code tools_in_user_message = true} by default). diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index 35930ff1..4e6c7335 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -14,7 +14,6 @@ public class Qwen3ChatFormat implements ChatFormat { protected final int endHeader; protected final int endOfTurn; protected final int endOfText; - protected final int endOfMessage; protected final int endOfTextFim; protected final int imStart; // beginOfText protected final int imEnd; // endOfText @@ -28,13 +27,12 @@ public Qwen3ChatFormat(Qwen3Tokenizer tokenizer, ChatTokens chatTokens) { this.tokenizer = tokenizer; this.chatTokens = chatTokens; Map specialTokens = tokenizer.getSpecialTokens(); - this.beginOfText = specialTokens.getOrDefault("", -1); + this.beginOfText = -1; // Qwen3 has no BOS token; getBeginOfText() falls back to startHeader this.startHeader = specialTokens.getOrDefault(chatTokens.tStartHeader(), -1); this.endHeader = specialTokens.getOrDefault(chatTokens.tEndHeader(), -1); this.endOfTurn = specialTokens.getOrDefault(chatTokens.tEndOfTurn(), -1); this.endOfText = specialTokens.getOrDefault(chatTokens.tEndOfText(), -1); this.endOfTextFim = specialTokens.getOrDefault(chatTokens.tEndOfTextFim(), -1); - this.endOfMessage = specialTokens.getOrDefault("", -1); // Use default value if key not found this.imStart = startHeader; this.imEnd = endHeader; @@ -132,6 +130,11 @@ public Set getStopTokens() { // ── Tool calling ────────────────────────────────────────────────────────── + @Override + public boolean supportsToolCalling() { + return true; + } + /** * Qwen3 tool calling system prompt suffix. * Appended to the system message; instructs the model to wrap tool calls in diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java index 146148fc..91446179 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -87,7 +87,7 @@ private static Optional parseLlamaJson(String json) { * Returns an empty list when the response contains no tool calls. */ public static List parseAllToolCalls(String responseText) { - List calls = new java.util.ArrayList<>(); + List calls = new ArrayList<>(); // <|python_tag|> (Llama 3.1) — single call by definition int pythonIdx = responseText.indexOf("<|python_tag|>"); @@ -193,6 +193,8 @@ public static String extractStringValue(String json, String key) { /** * Extracts the JSON object value for {@code "key": {…}} using brace-counting. * Handles nested objects and tolerates whitespace around {@code :}. + * Array brackets {@code […]} are tracked so that {@code {}/{}} characters inside + * array elements do not affect the outer brace depth counter. */ public static String extractNestedObject(String json, String key) { String marker = "\"" + key + "\""; @@ -203,10 +205,13 @@ public static String extractNestedObject(String json, String key) { int braceStart = json.indexOf('{', colonIdx + 1); if (braceStart == -1) return null; int depth = 0; + int arrayDepth = 0; for (int i = braceStart; i < json.length(); i++) { char c = json.charAt(i); - if (c == '{') depth++; - else if (c == '}') { + if (c == '[') arrayDepth++; + else if (c == ']') arrayDepth--; + else if (arrayDepth == 0 && c == '{') depth++; + else if (arrayDepth == 0 && c == '}') { if (--depth == 0) return json.substring(braceStart, i + 1); } } From f72ba117ed0e2e9d8ee5576f51c7f195818e93da Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 12:57:52 +0300 Subject: [PATCH 14/23] [tools] Extend `ToolCallExtract` with optional `id` and unify tool call JSON parsing logic --- .../model/format/ToolCallExtract.java | 11 ++++++++- .../model/format/ToolCallParserUtils.java | 24 +++++++++---------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java index 335c2176..b5f82c51 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java @@ -1,11 +1,20 @@ package org.beehive.gpullama3.model.format; +import java.util.Optional; + /** * Represents a single tool call extracted from a model response. * Contains the raw strings — JSON parsing of arguments is left to the caller. * * @param name the tool/function name to invoke * @param argumentsJson the arguments as a JSON object string, e.g. {"location":"Boston"} + * @param id optional tool call ID parsed from the model response; callers that + * generate IDs themselves (e.g. Ollama-style "call_XXXXXXXX") may pass + * {@link Optional#empty()} and let the consumer generate one */ -public record ToolCallExtract(String name, String argumentsJson) { +public record ToolCallExtract(String name, String argumentsJson, Optional id) { + + public ToolCallExtract(String name, String argumentsJson) { + this(name, argumentsJson, Optional.empty()); + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java index 91446179..906694dd 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -33,7 +33,7 @@ public static Optional parseLlamaResponse(String responseText) int idx = responseText.indexOf("<|python_tag|>"); if (idx != -1) { String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); - return parseLlamaJson(json); + return parseToolCallJson(json); } // 2. LLaMA 3.2 format: ... @@ -41,29 +41,29 @@ public static Optional parseLlamaResponse(String responseText) int tcEnd = responseText.lastIndexOf("
"); if (tcStart != -1 && tcEnd != -1 && tcEnd > tcStart) { String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); - return parseLlamaJson(json); + return parseToolCallJson(json); } // 2b. Unclosed — model stopped (eot_id / eom_id) before writing the closing tag if (tcStart != -1 && tcEnd == -1) { String json = responseText.substring(tcStart + "".length()).strip(); - return parseLlamaJson(json); + return parseToolCallJson(json); } // 3. Fallback: raw JSON, possibly inside markdown code fences String stripped = stripMarkdownFences(responseText.strip()); if (stripped.startsWith("{")) { - return parseLlamaJson(stripped); + return parseToolCallJson(stripped); } return Optional.empty(); } /** - * Parses a LLaMA-style tool call JSON object. + * Parses a tool call JSON object extracted from a {@code } block or raw JSON. * Accepts {@code {"name":…,"parameters":{…}}}, {@code {"function":…,"parameters":{…}}}, - * and {@code {"name":…,"arguments":{…}}}. + * and {@code {"name":…,"arguments":{…}}} — covering both LLaMA and Qwen3 variants. */ - private static Optional parseLlamaJson(String json) { + private static Optional parseToolCallJson(String json) { String name = extractStringValue(json, "name"); if (name == null) { name = extractStringValue(json, "function"); @@ -92,7 +92,7 @@ public static List parseAllToolCalls(String responseText) { // <|python_tag|> (Llama 3.1) — single call by definition int pythonIdx = responseText.indexOf("<|python_tag|>"); if (pythonIdx != -1) { - parseLlamaJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip()) + parseToolCallJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip()) .ifPresent(calls::add); return calls; } @@ -112,7 +112,7 @@ public static List parseAllToolCalls(String responseText) { json = responseText.substring(start + "".length()).strip(); searchFrom = responseText.length(); } - parseLlamaJson(json).ifPresent(calls::add); + parseToolCallJson(json).ifPresent(calls::add); if (end == -1) break; } @@ -120,7 +120,7 @@ public static List parseAllToolCalls(String responseText) { if (calls.isEmpty()) { String stripped = stripMarkdownFences(responseText.strip()); if (stripped.startsWith("{")) { - parseLlamaJson(stripped).ifPresent(calls::add); + parseToolCallJson(stripped).ifPresent(calls::add); } } @@ -166,7 +166,7 @@ public static String stripMarkdownFences(String text) { * Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"}) * inside the value, so multi-line code strings with embedded {@code "} are returned intact. */ - public static String extractStringValue(String json, String key) { + private static String extractStringValue(String json, String key) { String marker = "\"" + key + "\""; int markerIdx = json.indexOf(marker); if (markerIdx == -1) return null; @@ -196,7 +196,7 @@ public static String extractStringValue(String json, String key) { * Array brackets {@code […]} are tracked so that {@code {}/{}} characters inside * array elements do not affect the outer brace depth counter. */ - public static String extractNestedObject(String json, String key) { + private static String extractNestedObject(String json, String key) { String marker = "\"" + key + "\""; int markerIdx = json.indexOf(marker); if (markerIdx == -1) return null; From 2606f8b26a9e99fce76ac2b815d4f08c77a4bab0 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 19:04:18 +0300 Subject: [PATCH 15/23] [tools] Add support for batch tool call encoding across multiple chat formats --- .../gpullama3/model/format/ChatFormat.java | 22 +++++++++++++++++++ .../model/format/LlamaChatFormat.java | 18 +++++++++++++++ .../model/format/Qwen3ChatFormat.java | 22 +++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index c8ffbbf2..076c2a25 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -7,6 +7,7 @@ import org.beehive.gpullama3.tokenizer.Phi3Tokenizer; import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.Set; @@ -107,6 +108,27 @@ default List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); } + /** + * Re-encodes a prior assistant turn that contained one or more tool calls as a + * single assistant message. Implementations must emit all calls inside one + * header/footer pair so the model does not see spurious assistant turn boundaries. + * + *

The default delegates to {@link #encodeToolCallAssistantTurn(ToolCallExtract)} + * for single-element lists and naively concatenates individual encodings for larger + * lists — formats that support batch tool calls should override this method. + * + * @param toolCalls the ordered list of tool calls from a single assistant turn + */ + default List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(); + for (ToolCallExtract tc : toolCalls) { + tokens.addAll(encodeToolCallAssistantTurn(tc)); + } + return tokens; + } + /** * Encodes a tool execution result message in the model-native format. * diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 78461060..324843ed 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -162,6 +162,24 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St return tokens; } + /** + * Encodes multiple tool calls as a single assistant turn using {@code } blocks. + * For a single call, delegates to the existing single-call method (preserving the + * {@code <|python_tag|>} prefix on LLaMA 3.1). + */ + @Override + public List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + for (ToolCallExtract tc : toolCalls) { + String json = "{\"name\": \"" + tc.name() + "\", \"parameters\": " + tc.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeAsList("\n" + json + "\n\n")); + } + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); + return tokens; + } + /** * Detects a tool call in the decoded response text. * Supports LLaMA 3.1 (native {@code <|python_tag|>} + {@code "parameters"} key), diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index 4e6c7335..2006390a 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -172,6 +172,28 @@ public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { return tokens; } + /** + * Encodes multiple tool calls as a single assistant turn: one {@code <|im_start|>assistant} + * header, all {@code } blocks concatenated, then {@code <|im_end|>}. + * For a single call, delegates to the existing single-call method. + */ + @Override + public List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + for (ToolCallExtract tc : toolCalls) { + String json = "{\"name\":\"" + tc.name() + "\",\"arguments\":" + tc.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + } + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + /** * Encodes a tool result using the Qwen3 "tool" role. * Format: {@code <|im_start|>tool\nresult<|im_end|>} From 8295ebd8cc793c41a027b01b61ed7dc1fe0634c1 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 21:13:25 +0300 Subject: [PATCH 16/23] [tools] Unify tool call parsing logic, streamline method names, and extend support for multi-model formats --- .../model/format/LlamaChatFormat.java | 15 +++++-- .../model/format/Qwen3ChatFormat.java | 4 +- .../model/format/ToolCallParserUtils.java | 41 ++++--------------- 3 files changed, 22 insertions(+), 38 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 324843ed..7779edd4 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -163,9 +163,11 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St } /** - * Encodes multiple tool calls as a single assistant turn using {@code } blocks. + * Encodes multiple tool calls as a single assistant turn. * For a single call, delegates to the existing single-call method (preserving the * {@code <|python_tag|>} prefix on LLaMA 3.1). + * For multiple calls, LLaMA 3.1 prefixes each with {@code <|python_tag|>}; + * LLaMA 3.2 (no python_tag) uses {@code } blocks. */ @Override public List encodeToolCallAssistantTurn(List toolCalls) { @@ -174,7 +176,12 @@ public List encodeToolCallAssistantTurn(List toolCalls List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); for (ToolCallExtract tc : toolCalls) { String json = "{\"name\": \"" + tc.name() + "\", \"parameters\": " + tc.argumentsJson() + "}"; - tokens.addAll(tokenizer.encodeAsList("\n" + json + "\n\n")); + if (pythonTag != -1) { + tokens.add(pythonTag); + tokens.addAll(tokenizer.encodeAsList(json + "\n")); + } else { + tokens.addAll(tokenizer.encodeAsList("\n" + json + "\n\n")); + } } tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); return tokens; @@ -184,11 +191,11 @@ public List encodeToolCallAssistantTurn(List toolCalls * Detects a tool call in the decoded response text. * Supports LLaMA 3.1 (native {@code <|python_tag|>} + {@code "parameters"} key), * LLaMA 3.2 ({@code "arguments"} key, tag often absent), and a raw-JSON fallback - * for smaller models. Delegates to {@link ToolCallParserUtils#parseLlamaResponse}. + * for smaller models. Delegates to {@link ToolCallParserUtils#parseToolCallResponse}. */ @Override public Optional extractToolCall(String responseText) { - return ToolCallParserUtils.parseLlamaResponse(responseText); + return ToolCallParserUtils.parseToolCallResponse(responseText); } @Override diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index 2006390a..43a8a8ec 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -212,11 +212,11 @@ public List encodeToolResultTurn(String toolCallId, String toolName, St /** * Detects a tool call enclosed in {@code } tags. - * Delegates to {@link ToolCallParserUtils#parseQwen3Response}. + * Delegates to {@link ToolCallParserUtils#parseToolCallResponse}. */ @Override public Optional extractToolCall(String responseText) { - return ToolCallParserUtils.parseQwen3Response(responseText); + return ToolCallParserUtils.parseToolCallResponse(responseText); } @Override diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java index 906694dd..fb94fa38 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -17,18 +17,17 @@ private ToolCallParserUtils() {} // ── Llama ───────────────────────────────────────────────────────────────── /** - * Extracts a tool call from a LLaMA 3.1 or 3.2 model response. + * Extracts a single tool call from a model response text. * - * Recognised formats: - * 1. {@code <|python_tag|>{"name":…,"parameters":{…}}} — LLaMA 3.1 native, also accepted by 3.2 - * 2. Raw JSON with {@code "arguments"} key instead of {@code "parameters"} — LLaMA 3.2 instruction format - * 3. Raw JSON object optionally inside markdown code fences — fallback for models that - * follow system-prompt instructions but omit the special-token prefix + * Recognised formats (in priority order): + * 1. {@code <|python_tag|>{…}} — LLaMA 3.1 native + * 2. {@code } — LLaMA 3.2 and Qwen3 (closed or unclosed) + * 3. Raw JSON object optionally inside markdown code fences — fallback * - * Both {@code "parameters"} and {@code "arguments"} are tried so a single implementation - * handles the 3.1 and 3.2 variants transparently. + * Both {@code "parameters"} and {@code "arguments"} are tried as the argument key, + * covering LLaMA 3.1/3.2 and Qwen3 variants transparently. */ - public static Optional parseLlamaResponse(String responseText) { + public static Optional parseToolCallResponse(String responseText) { // 1. Native LLaMA 3.1 format: <|python_tag|>{...} int idx = responseText.indexOf("<|python_tag|>"); if (idx != -1) { @@ -77,7 +76,7 @@ private static Optional parseToolCallJson(String json) { return Optional.of(new ToolCallExtract(name, argsJson)); } - // ── Qwen3 ───────────────────────────────────────────────────────────────── + // ── Batch extraction ────────────────────────────────────────────────────── /** * Extracts ALL tool calls from a response that may contain multiple @@ -127,28 +126,6 @@ public static List parseAllToolCalls(String responseText) { return calls; } - /** - * Extracts a tool call enclosed in {@code } tags - * as produced by Qwen3 models. - */ - public static Optional parseQwen3Response(String responseText) { - int start = responseText.indexOf(""); - int end = responseText.lastIndexOf(""); - if (start == -1) return Optional.empty(); - - String json = (end != -1 && end > start) - ? responseText.substring(start + "".length(), end).strip() - : responseText.substring(start + "".length()).strip(); - - String name = extractStringValue(json, "name"); - if (name == null) return Optional.empty(); - - String argsJson = extractNestedObject(json, "arguments"); - if (argsJson == null) argsJson = "{}"; - - return Optional.of(new ToolCallExtract(name, argsJson)); - } - // ── Shared helpers ──────────────────────────────────────────────────────── /** Strips surrounding markdown code fences (```…```) if present. */ From ce83aea3344696799a236013a85f02d4b49352e8 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 21:13:54 +0300 Subject: [PATCH 17/23] Add default temperature and top-p resolution based on model formats, adjust `Options` validation and defaults --- src/main/java/org/beehive/gpullama3/Options.java | 16 ++++++++-------- .../gpullama3/inference/sampler/Sampler.java | 8 +++++++- .../gpullama3/model/format/ChatFormat.java | 16 ++++++++++++++++ .../gpullama3/model/format/LlamaChatFormat.java | 10 ++++++++++ .../gpullama3/model/format/Qwen3ChatFormat.java | 10 ++++++++++ 5 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java index 919f9751..274428b7 100644 --- a/src/main/java/org/beehive/gpullama3/Options.java +++ b/src/main/java/org/beehive/gpullama3/Options.java @@ -11,8 +11,8 @@ public record Options(Path modelPath, String prompt, String systemPrompt, String public Options { require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); - require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); - require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); + require(Float.isNaN(temperature) || 0 <= temperature, "Invalid argument: --temperature must be non-negative"); + require(Float.isNaN(topp) || 0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1"); require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode"); // Publish to system properties so TornadoVMMasterPlan and Llama read the right values @@ -44,8 +44,8 @@ public static void printUsage(PrintStream out) { out.println(" --prompt, -p input prompt"); out.println(" --system-prompt, -sp (optional) system prompt (Llama models)"); out.println(" --suffix suffix for fill-in-the-middle request (Codestral)"); - out.println(" --temperature, -temp temperature in [0,inf], default 0.1"); - out.println(" --top-p p value in top-p (nucleus) sampling in [0,1] default 0.95"); + out.println(" --temperature, -temp temperature in [0,inf], default: auto-detected from model family"); + out.println(" --top-p p value in top-p (nucleus) sampling in [0,1], default: auto-detected from model family"); out.println(" --seed random seed, default System.nanoTime()"); out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS); out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true"); @@ -59,8 +59,8 @@ public static Options getDefaultOptions() { String prompt = "Tell me a story with Java"; // Hardcoded for testing String systemPrompt = null; String suffix = null; - float temperature = 0.1f; - float topp = 0.95f; + float temperature = Float.NaN; // resolved from model family after loading + float topp = Float.NaN; // resolved from model family after loading Path modelPath = null; long seed = System.nanoTime(); int maxTokens = DEFAULT_MAX_TOKENS; @@ -76,8 +76,8 @@ public static Options parseOptions(String[] args) { String prompt = "Tell me a story with Java"; // Hardcoded for testing String systemPrompt = null; String suffix = null; - float temperature = 0.1f; - float topp = 0.95f; + float temperature = Float.NaN; // resolved from model family after loading + float topp = Float.NaN; // resolved from model family after loading Path modelPath = null; long seed = System.nanoTime(); int maxTokens = DEFAULT_MAX_TOKENS; diff --git a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java index 496d0761..bbb7c5ad 100644 --- a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java +++ b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java @@ -122,7 +122,13 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp, } static Sampler createSampler(Model model, Options options) { - return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); + float temperature = Float.isNaN(options.temperature()) + ? (float) model.chatFormat().defaultTemperature() + : options.temperature(); + float topp = Float.isNaN(options.topp()) + ? (float) model.chatFormat().defaultTopP() + : options.topp(); + return selectSampler(model.configuration().vocabularySize(), temperature, topp, options.seed()); } /** diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 076c2a25..9f7121d4 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -162,6 +162,22 @@ default List extractAllToolCalls(String responseText) { return extractToolCall(responseText).map(List::of).orElse(List.of()); } + /** + * Returns the recommended default temperature for this chat format. + * Used when the caller has not explicitly configured a temperature. + */ + default double defaultTemperature() { + return 0.7; + } + + /** + * Returns the recommended default top-p for this chat format. + * Used when the caller has not explicitly configured a top-p value. + */ + default double defaultTopP() { + return 0.9; + } + /** * Stop tokens to use when tool calling is enabled. * Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 7779edd4..f23e3c26 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -75,6 +75,16 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, List getStopTokens() { return stopTokens; } + @Override + public double defaultTemperature() { + return 0.8; + } + + @Override + public double defaultTopP() { + return 0.9; + } + // ── Tool calling ────────────────────────────────────────────────────────── @Override From 76058fa2434cb6ec1aa26f2593941af5a85ab424 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 15 May 2026 21:31:26 +0300 Subject: [PATCH 18/23] [tools] Simplify section comments in `ToolCallParserUtils` for improved readability --- .../beehive/gpullama3/model/format/ToolCallParserUtils.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java index fb94fa38..a0a856a0 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -14,8 +14,6 @@ public final class ToolCallParserUtils { private ToolCallParserUtils() {} - // ── Llama ───────────────────────────────────────────────────────────────── - /** * Extracts a single tool call from a model response text. * @@ -76,7 +74,7 @@ private static Optional parseToolCallJson(String json) { return Optional.of(new ToolCallExtract(name, argsJson)); } - // ── Batch extraction ────────────────────────────────────────────────────── + // Batch extraction /** * Extracts ALL tool calls from a response that may contain multiple @@ -126,7 +124,7 @@ public static List parseAllToolCalls(String responseText) { return calls; } - // ── Shared helpers ──────────────────────────────────────────────────────── + // Shared helpers /** Strips surrounding markdown code fences (```…```) if present. */ public static String stripMarkdownFences(String text) { From a28731212e363df4d459c7c290eef9555ad0c2b3 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 12 Jun 2026 14:52:42 +0300 Subject: [PATCH 19/23] [tools][fix] Enhance JSON parsing in `ToolCallParserUtils` with string-aware brace counting --- .../model/format/ToolCallParserUtils.java | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java index a0a856a0..0691b4e7 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -168,8 +168,11 @@ private static String extractStringValue(String json, String key) { /** * Extracts the JSON object value for {@code "key": {…}} using brace-counting. * Handles nested objects and tolerates whitespace around {@code :}. - * Array brackets {@code […]} are tracked so that {@code {}/{}} characters inside - * array elements do not affect the outer brace depth counter. + * + *

Brace counting is string-aware: {@code {} and } characters appearing inside + * JSON string literals (e.g. a {@code "code"} argument containing Java source) do not affect + * the depth counter, and {@code \"} escapes inside strings are skipped. This keeps argument + * objects whose string values contain braces intact. */ private static String extractNestedObject(String json, String key) { String marker = "\"" + key + "\""; @@ -180,13 +183,20 @@ private static String extractNestedObject(String json, String key) { int braceStart = json.indexOf('{', colonIdx + 1); if (braceStart == -1) return null; int depth = 0; - int arrayDepth = 0; + boolean inString = false; for (int i = braceStart; i < json.length(); i++) { char c = json.charAt(i); - if (c == '[') arrayDepth++; - else if (c == ']') arrayDepth--; - else if (arrayDepth == 0 && c == '{') depth++; - else if (arrayDepth == 0 && c == '}') { + if (inString) { + if (c == '\\') { + i++; // skip the escaped character (e.g. \", \\, \n) + } else if (c == '"') { + inString = false; + } + } else if (c == '"') { + inString = true; + } else if (c == '{') { + depth++; + } else if (c == '}') { if (--depth == 0) return json.substring(braceStart, i + 1); } } From 828b2f02846f9e4f376a48bdb0648cddc0880a6d Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 12 Jun 2026 14:53:34 +0300 Subject: [PATCH 20/23] [tools][fix] Update Qwen3ChatFormat to encode tool results using `` tags --- .../beehive/gpullama3/model/format/Qwen3ChatFormat.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index 8fc761c1..0b12c8de 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -205,15 +205,16 @@ public List encodeToolCallAssistantTurn(List toolCalls } /** - * Encodes a tool result using the Qwen3 "tool" role. - * Format: {@code <|im_start|>tool\nresult<|im_end|>} + * Encodes a tool result in the native Qwen3 format: a {@code user} turn whose content is the + * result wrapped in {@code } tags, matching the official Qwen3 + * chat template. (Qwen3 has no dedicated "tool" role — results are delivered as user turns.) + * Format: {@code <|im_start|>user\n\nresult\n<|im_end|>} */ @Override public List encodeToolResultTurn(String toolCallId, String toolName, String result) { List tokens = new ArrayList<>(); tokens.add(imStart); - tokens.addAll(tokenizer.encodeOrdinaryAsList("tool\n")); - tokens.addAll(tokenizer.encodeOrdinaryAsList(result)); + tokens.addAll(tokenizer.encodeOrdinaryAsList("user\n\n" + result + "\n")); if (imEnd != -1) { tokens.add(imEnd); } From fa8429edb4271142c533e1208ccd4fe186b56333 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 12 Jun 2026 14:54:14 +0300 Subject: [PATCH 21/23] [tools][test] Add unit tests for `ToolCallParserUtils`, covering single and batch parsing scenarios --- .../model/format/ToolCallParserUtilsTest.java | 166 ++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 src/test/java/org/beehive/gpullama3/model/format/ToolCallParserUtilsTest.java diff --git a/src/test/java/org/beehive/gpullama3/model/format/ToolCallParserUtilsTest.java b/src/test/java/org/beehive/gpullama3/model/format/ToolCallParserUtilsTest.java new file mode 100644 index 00000000..74e74ea6 --- /dev/null +++ b/src/test/java/org/beehive/gpullama3/model/format/ToolCallParserUtilsTest.java @@ -0,0 +1,166 @@ +package org.beehive.gpullama3.model.format; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.List; +import java.util.Optional; + +import org.junit.Test; + +/** + * Unit tests for {@link ToolCallParserUtils}. + * + *

The parser is pure string-handling (no model/tokenizer), so every recognised response + * shape is covered here: Qwen3/Llama {@code } tags, Llama 3.1 {@code <|python_tag|>}, + * raw-JSON and markdown-fence fallbacks, unclosed tags, and batch (multi-call) responses. + * The brace-in-string cases pin the fix that keeps argument objects whose string values + * contain {@code {}/{}} characters (e.g. source code) intact. + */ +public class ToolCallParserUtilsTest { + + // ── Single-call extraction ──────────────────────────────────────────────── + + @Test + public void qwen3ToolCall_arguments() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Chania\"}}\n"); + assertTrue(tc.isPresent()); + assertEquals("get_weather", tc.get().name()); + assertEquals("{\"city\": \"Chania\"}", tc.get().argumentsJson()); + } + + @Test + public void llama31_pythonTag_parametersKey() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "<|python_tag|>{\"name\": \"get_weather\", \"parameters\": {\"city\": \"Boston\"}}"); + assertTrue(tc.isPresent()); + assertEquals("get_weather", tc.get().name()); + assertEquals("{\"city\": \"Boston\"}", tc.get().argumentsJson()); + } + + @Test + public void functionKey_usedAsNameFallback() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"function\": \"list_dir\", \"arguments\": {\"path\": \"/tmp\"}}"); + assertTrue(tc.isPresent()); + assertEquals("list_dir", tc.get().name()); + } + + @Test + public void missingArguments_defaultsToEmptyObject() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"now\"}"); + assertTrue(tc.isPresent()); + assertEquals("now", tc.get().name()); + assertEquals("{}", tc.get().argumentsJson()); + } + + @Test + public void unclosedToolCall_stillParsed() { + // Model stopped (eot/eom) before emitting the closing tag. + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"ping\", \"arguments\": {\"host\": \"a\"}}"); + assertTrue(tc.isPresent()); + assertEquals("ping", tc.get().name()); + assertEquals("{\"host\": \"a\"}", tc.get().argumentsJson()); + } + + @Test + public void plainTextResponse_isNotAToolCall() { + assertFalse(ToolCallParserUtils.parseToolCallResponse("The weather in Chania is sunny.").isPresent()); + } + + // ── Brace-in-string argument objects (the core fix) ─────────────────────── + + @Test + public void argumentsWithBracesInStringValue_keptIntact() { + String args = "{\"code\": \"public class A { void m() { return; } }\"}"; + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"write_file\", \"arguments\": " + args + "}"); + assertTrue(tc.isPresent()); + assertEquals("write_file", tc.get().name()); + assertEquals(args, tc.get().argumentsJson()); + } + + @Test + public void argumentsWithEscapedQuotesAndBraces_keptIntact() { + String args = "{\"snippet\": \"if (s.equals(\\\"}\\\")) { x++; }\"}"; + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"run\", \"arguments\": " + args + "}"); + assertTrue(tc.isPresent()); + assertEquals(args, tc.get().argumentsJson()); + } + + @Test + public void argumentsWithNestedObjectsAndArrays_keptIntact() { + String args = "{\"items\": [{\"a\": 1}, {\"b\": 2}], \"meta\": {\"n\": 3}}"; + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"batch\", \"arguments\": " + args + "}"); + assertTrue(tc.isPresent()); + assertEquals(args, tc.get().argumentsJson()); + } + + // ── Fallbacks ───────────────────────────────────────────────────────────── + + @Test + public void rawJsonFallback_noTags() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"echo\", \"arguments\": {\"msg\": \"hi\"}}"); + assertTrue(tc.isPresent()); + assertEquals("echo", tc.get().name()); + } + + @Test + public void markdownFencedJson_fallback() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "```json\n{\"name\": \"echo\", \"arguments\": {\"msg\": \"hi\"}}\n```"); + assertTrue(tc.isPresent()); + assertEquals("echo", tc.get().name()); + assertEquals("{\"msg\": \"hi\"}", tc.get().argumentsJson()); + } + + @Test + public void stripMarkdownFences_removesFenceLines() { + assertEquals("body", ToolCallParserUtils.stripMarkdownFences("```\nbody\n```")); + assertEquals("plain", ToolCallParserUtils.stripMarkdownFences("plain")); + } + + // ── Batch (multiple tool calls) ─────────────────────────────────────────── + + @Test + public void batch_multipleToolCallBlocks() { + List calls = ToolCallParserUtils.parseAllToolCalls( + "{\"name\": \"a\", \"arguments\": {\"x\": 1}}" + + "{\"name\": \"b\", \"arguments\": {\"y\": 2}}"); + assertEquals(2, calls.size()); + assertEquals("a", calls.get(0).name()); + assertEquals("{\"x\": 1}", calls.get(0).argumentsJson()); + assertEquals("b", calls.get(1).name()); + assertEquals("{\"y\": 2}", calls.get(1).argumentsJson()); + } + + @Test + public void batch_bracesInStringDoNotBleedAcrossCalls() { + List calls = ToolCallParserUtils.parseAllToolCalls( + "{\"name\": \"write\", \"arguments\": {\"code\": \"a { b }\"}}" + + "{\"name\": \"log\", \"arguments\": {\"msg\": \"ok\"}}"); + assertEquals(2, calls.size()); + assertEquals("{\"code\": \"a { b }\"}", calls.get(0).argumentsJson()); + assertEquals("log", calls.get(1).name()); + } + + @Test + public void batch_pythonTagIsSingleCall() { + List calls = ToolCallParserUtils.parseAllToolCalls( + "<|python_tag|>{\"name\": \"a\", \"parameters\": {\"x\": 1}}"); + assertEquals(1, calls.size()); + assertEquals("a", calls.get(0).name()); + } + + @Test + public void batch_noToolCalls_returnsEmpty() { + assertTrue(ToolCallParserUtils.parseAllToolCalls("just a plain answer").isEmpty()); + } +} From 674170ffaeb723106607d9f8e351377b8c3f1077 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 12 Jun 2026 15:29:51 +0300 Subject: [PATCH 22/23] Add thinking on/off control support --- .../gpullama3/model/format/ChatFormat.java | 24 +++++++++++++++++ .../model/format/Qwen3ChatFormat.java | 26 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 9f7121d4..d23738a0 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -187,6 +187,30 @@ default Set getToolAwareStopTokens() { return getStopTokens(); } + /** + * Returns {@code true} when this chat format has a controllable thinking/reasoning mode that + * {@link #encodeThinkingControl(boolean)} can toggle (e.g. Qwen3). Formats that return + * {@code false} (the default) have no reasoning phase to switch on or off, so the + * {@code enableThinking} flag is inert for them. Pure reasoning models that always think and + * offer no off-switch (e.g. DeepSeek-R1) also return {@code false}. + */ + default boolean supportsThinking() { + return false; + } + + /** + * Returns the tokens to append immediately after the assistant header in order to control + * the model's thinking/reasoning phase. Models that do not {@link #supportsThinking()} return + * an empty list (the default), so callers can invoke this unconditionally. + * + * @param enableThinking when {@code false}, returns the model-native primer that suppresses + * reasoning (e.g. Qwen3's pre-closed {@code } block); when {@code true}, + * returns an empty list so the model decides for itself. + */ + default List encodeThinkingControl(boolean enableThinking) { + return List.of(); + } + record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) { } diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index 0b12c8de..e7b99e2c 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -138,6 +138,32 @@ public double defaultTopP() { return 0.9; } + /** + * Genuine Qwen3 exposes the {@code enable_thinking} template switch and so supports thinking + * control. DeepSeek-R1 is routed through this same format (detected by the absence of an + * {@code <|im_end|>} token) but is a pure reasoning model with no off-switch, so it reports + * {@code false} and is left to always reason. + */ + @Override + public boolean supportsThinking() { + return imEnd != -1; + } + + /** + * Qwen3 thinking control. When thinking is disabled, primes a pre-closed + * {@code \n\n\n\n} block right after the assistant header so the model skips + * its reasoning phase — matching the {@code enable_thinking=false} branch of the official + * Qwen3 chat template. When enabled (or for DeepSeek-R1, which cannot disable thinking), + * returns nothing and lets the model reason on its own. + */ + @Override + public List encodeThinkingControl(boolean enableThinking) { + if (enableThinking || !supportsThinking()) { + return List.of(); + } + return tokenizer.encodeOrdinaryAsList("\n\n\n\n"); + } + // ── Tool calling ────────────────────────────────────────────────────────── @Override From 91333d13f560fbd3d27ff5acc19f306a1bf6d665 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Mon, 15 Jun 2026 15:03:46 +0300 Subject: [PATCH 23/23] Add support for canonical `` token tracking and usage in thinking control encoding --- .../model/format/Qwen3ChatFormat.java | 18 +++++++++++++++++- .../gpullama3/tokenizer/Qwen3Tokenizer.java | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index e7b99e2c..5d121c31 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -155,13 +155,29 @@ public boolean supportsThinking() { * its reasoning phase — matching the {@code enable_thinking=false} branch of the official * Qwen3 chat template. When enabled (or for DeepSeek-R1, which cannot disable thinking), * returns nothing and lets the model reason on its own. + * + *

The {@code }/{@code } markers are emitted as their canonical + * single token ids (not ordinary BPE sub-pieces): the tokenizer strips them from its special + * map so reasoning renders as text, but the model only recognises the closed block — and thus + * actually skips reasoning — when it sees the real control tokens it was trained on. */ @Override public List encodeThinkingControl(boolean enableThinking) { if (enableThinking || !supportsThinking()) { return List.of(); } - return tokenizer.encodeOrdinaryAsList("\n\n\n\n"); + int thinkStart = tokenizer.getThinkStartToken(); + int thinkEnd = tokenizer.getThinkEndToken(); + if (thinkStart == -1 || thinkEnd == -1) { + // GGUF without dedicated think tokens — fall back to ordinary text encoding. + return tokenizer.encodeOrdinaryAsList("\n\n\n\n"); + } + List tokens = new ArrayList<>(); + tokens.add(thinkStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n\n")); + tokens.add(thinkEnd); + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n\n")); + return tokens; } // ── Tool calling ────────────────────────────────────────────────────────── diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java index 077dd536..ea122d44 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java @@ -25,6 +25,11 @@ public class Qwen3Tokenizer implements Tokenizer { private final Map, Integer> merges; private final Map specialTokens; private final int[] tokenTypes; + /** Canonical {@code }/{@code } token ids (or -1), captured before they are + * removed from {@link #specialTokens} so reasoning renders as text. Used to prime the + * thinking-disabled control block with the tokens the model was actually trained on. */ + private final int thinkStartToken; + private final int thinkEndToken; /** buffer to store incomplete UTF-8 sequence */ private final byte[] bufUtf8 = new byte[4]; /** index in UTF-8 buffer */ @@ -59,6 +64,10 @@ public Qwen3Tokenizer(Map metadata, Vocabulary vocabulary, boole i -> specialTokensList.get(i), i -> baseTokens + i) ); + // Capture the canonical think-control token ids BEFORE removing them from the special + // map (they are removed so reasoning text renders verbatim during decode). + this.thinkStartToken = specialTokens.getOrDefault("", -1); + this.thinkEndToken = specialTokens.getOrDefault("", -1); specialTokens.remove(""); specialTokens.remove(""); @@ -145,6 +154,16 @@ public Map getSpecialTokens() { return specialTokens; } + /** Canonical {@code } token id, or {@code -1} if this GGUF has no think tokens. */ + public int getThinkStartToken() { + return thinkStartToken; + } + + /** Canonical {@code } token id, or {@code -1} if this GGUF has no think tokens. */ + public int getThinkEndToken() { + return thinkEndToken; + } + @Override public boolean isSpecialToken(int tokenIndex) { return specialTokens.containsValue(tokenIndex);