diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index 2efab197..f60083cc 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -177,15 +177,62 @@ jobs: flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs metrics_file: ${{ runner.temp }}/metrics-ptx-llama-1b-f16-batch-prefill-decode-cuda-graphs.json - - name: FP16 - Run Qwen3-4B-f16.gguf + - name: FP16 - Run Qwen3-0.6B-f16.gguf uses: ./.github/actions/run-inference with: backend: ${{ matrix.backend.name }} - model_file: Qwen3-4B-f16.gguf - model: Qwen3-4B + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B quantization: F16 configuration: standard - metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-4b-f16-standard.json + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-standard.json + + - name: FP16 - Run Qwen3-0.6B-f16.gguf - Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B + quantization: F16 + configuration: prefill-decode + flags: --with-prefill-decode + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-prefill-decode.json + + - name: FP16 - Run Qwen3-0.6B-f16.gguf - Batch-Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B + quantization: F16 + configuration: batch-prefill-decode + flags: --with-prefill-decode --batch-prefill-size 32 + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0.6b-f16-batch-prefill-decode.json + + # PTX-only: CUDA-graph variants + - name: PTX - FP16 - Run Qwen3-0.6B-f16.gguf - Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B + quantization: F16 + configuration: prefill-decode-cuda-graphs + flags: --with-prefill-decode --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0.6b-f16-prefill-decode-cuda-graphs.json + + - name: PTX - FP16 - Run Qwen3-0.6B-f16.gguf - Batch-Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-f16.gguf + model: Qwen3-0.6B + quantization: F16 + configuration: batch-prefill-decode-cuda-graphs + flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0.6b-f16-batch-prefill-decode-cuda-graphs.json - name: FP16 - Run Mistral-7B-Instruct-v0.3.fp16.gguf uses: ./.github/actions/run-inference @@ -358,6 +405,53 @@ jobs: configuration: standard metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-standard.json + - name: Q8 - Run Qwen3-0.6B-Q8_0.gguf - Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: prefill-decode + flags: --with-prefill-decode + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-prefill-decode.json + + - name: Q8 - Run Qwen3-0.6B-Q8_0.gguf - Batch-Prefill-Decode + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: batch-prefill-decode + flags: --with-prefill-decode --batch-prefill-size 32 + metrics_file: ${{ runner.temp }}/metrics-${{ matrix.backend.name }}-qwen3-0-6b-q8-batch-prefill-decode.json + + # PTX-only: CUDA-graph variants + - name: PTX - Q8 - Run Qwen3-0.6B-Q8_0.gguf - Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: prefill-decode-cuda-graphs + flags: --with-prefill-decode --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0-6b-q8-prefill-decode-cuda-graphs.json + + - name: PTX - Q8 - Run Qwen3-0.6B-Q8_0.gguf - Batch-Prefill-Decode-CUDA-Graphs + if: matrix.backend.name == 'ptx' + uses: ./.github/actions/run-inference + with: + backend: ${{ matrix.backend.name }} + model_file: Qwen3-0.6B-Q8_0.gguf + model: Qwen3-0.6B + quantization: Q8_0 + configuration: batch-prefill-decode-cuda-graphs + flags: --with-prefill-decode --batch-prefill-size 32 --cuda-graphs + metrics_file: ${{ runner.temp }}/metrics-ptx-qwen3-0-6b-q8-batch-prefill-decode-cuda-graphs.json + - name: Q8 - Run Phi-3-mini-4k-instruct-Q8_0.gguf uses: ./.github/actions/run-inference with: diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java index 919f9751..274428b7 100644 --- a/src/main/java/org/beehive/gpullama3/Options.java +++ b/src/main/java/org/beehive/gpullama3/Options.java @@ -11,8 +11,8 @@ public record Options(Path modelPath, String prompt, String systemPrompt, String public Options { require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); - require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); - require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); + require(Float.isNaN(temperature) || 0 <= temperature, "Invalid argument: --temperature must be non-negative"); + require(Float.isNaN(topp) || 0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); require(batchPrefillSize >= 1, "Invalid argument: --batch-prefill-size must be >= 1"); require(batchPrefillSize == 1 || withPrefillDecode, "Invalid argument: --batch-prefill-size requires --with-prefill-decode"); // Publish to system properties so TornadoVMMasterPlan and Llama read the right values @@ -44,8 +44,8 @@ public static void printUsage(PrintStream out) { out.println(" --prompt, -p input prompt"); out.println(" --system-prompt, -sp (optional) system prompt (Llama models)"); out.println(" --suffix suffix for fill-in-the-middle request (Codestral)"); - out.println(" --temperature, -temp temperature in [0,inf], default 0.1"); - out.println(" --top-p p value in top-p (nucleus) sampling in [0,1] default 0.95"); + out.println(" --temperature, -temp temperature in [0,inf], default: auto-detected from model family"); + out.println(" --top-p p value in top-p (nucleus) sampling in [0,1], default: auto-detected from model family"); out.println(" --seed random seed, default System.nanoTime()"); out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS); out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true"); @@ -59,8 +59,8 @@ public static Options getDefaultOptions() { String prompt = "Tell me a story with Java"; // Hardcoded for testing String systemPrompt = null; String suffix = null; - float temperature = 0.1f; - float topp = 0.95f; + float temperature = Float.NaN; // resolved from model family after loading + float topp = Float.NaN; // resolved from model family after loading Path modelPath = null; long seed = System.nanoTime(); int maxTokens = DEFAULT_MAX_TOKENS; @@ -76,8 +76,8 @@ public static Options parseOptions(String[] args) { String prompt = "Tell me a story with Java"; // Hardcoded for testing String systemPrompt = null; String suffix = null; - float temperature = 0.1f; - float topp = 0.95f; + float temperature = Float.NaN; // resolved from model family after loading + float topp = Float.NaN; // resolved from model family after loading Path modelPath = null; long seed = System.nanoTime(); int maxTokens = DEFAULT_MAX_TOKENS; diff --git a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java index 496d0761..bbb7c5ad 100644 --- a/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java +++ b/src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java @@ -122,7 +122,13 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp, } static Sampler createSampler(Model model, Options options) { - return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed()); + float temperature = Float.isNaN(options.temperature()) + ? (float) model.chatFormat().defaultTemperature() + : options.temperature(); + float topp = Float.isNaN(options.topp()) + ? (float) model.chatFormat().defaultTopP() + : options.topp(); + return selectSampler(model.configuration().vocabularySize(), temperature, topp, options.seed()); } /** diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java index d70625b9..c84ccf96 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java @@ -35,6 +35,18 @@ public Qwen3State(Configuration config, int batchsize) { this.tempKcur = new FloatArray(nEmbdHead); } + @Override + protected int batchQDim(Configuration config) { + Qwen3Configuration q3 = (Qwen3Configuration) config; + return q3.numberOfHeadsKey() * q3.numberOfHeads(); + } + + @Override + protected int batchKvDim(Configuration config) { + Qwen3Configuration q3 = (Qwen3Configuration) config; + return q3.numberOfHeadsValue() * q3.numberOfKeyValueHeads(); + } + @Override protected StateFields createStateFields(Configuration configuration) { StateFields fields = new StateFields(); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 4e448508..0807b756 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -77,10 +77,10 @@ public abstract class State { public final HalfFloatArray embeddingXBatch; // B × dim (FP16 input) public final FloatArray wrapXBatch; // B × dim (live activations / Q8_0 dequant) public final HalfFloatArray wrapXbFP16Batch; // B × dim (RMSNorm output, FP16) - public final FloatArray wrapQBatch; // B × dim + public final FloatArray wrapQBatch; // B × qDim (Q projection) public final FloatArray wrapKBatch; // B × kvDim public final FloatArray wrapVBatch; // B × kvDim - public final FloatArray wrapXbBatch; // B × dim (attention output) + public final FloatArray wrapXbBatch; // B × qDim (attention output) public final FloatArray wrapHbBatch; // B × hiddenDim public final FloatArray attnScaleBatch; // B (per-token RMS scale, attn) public final FloatArray ffnScaleBatch; // B (per-token RMS scale, FFN) @@ -135,14 +135,15 @@ protected State(Configuration config, int batchsize) { int gpuBatchSize = Integer.getInteger("llama.prefillBatchSize", 1); if (gpuBatchSize > 1) { - int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int qDim = batchQDim(config); + int kvDim = batchKvDim(config); this.embeddingXBatch = new HalfFloatArray(gpuBatchSize * config.dim()); this.wrapXBatch = new FloatArray(gpuBatchSize * config.dim()); this.wrapXbFP16Batch = new HalfFloatArray(gpuBatchSize * config.dim()); - this.wrapQBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapQBatch = new FloatArray(gpuBatchSize * qDim); this.wrapKBatch = new FloatArray(gpuBatchSize * kvDim); this.wrapVBatch = new FloatArray(gpuBatchSize * kvDim); - this.wrapXbBatch = new FloatArray(gpuBatchSize * config.dim()); + this.wrapXbBatch = new FloatArray(gpuBatchSize * qDim); this.wrapHbBatch = new FloatArray(gpuBatchSize * config.hiddenDim()); this.attnScaleBatch = new FloatArray(gpuBatchSize); this.ffnScaleBatch = new FloatArray(gpuBatchSize); @@ -162,6 +163,16 @@ protected State(Configuration config, int batchsize) { } } + /** Q-projection output dimension per token (model specific: = dim for Llama; differs for Qwen3). */ + protected int batchQDim(Configuration config) { + return config.dim(); + } + + /** KV-cache dimension per token (model specific: = dim*nHeadKv/nHeads for Llama; differs for Qwen3). */ + protected int batchKvDim(Configuration config) { + return (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + } + // Abstract method - subclasses implement their specific allocation logic and sizes protected abstract StateFields createStateFields(Configuration config); diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 827ad625..d23738a0 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -7,7 +7,9 @@ import org.beehive.gpullama3.tokenizer.Phi3Tokenizer; import org.beehive.gpullama3.tokenizer.Qwen3Tokenizer; +import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.Set; public interface ChatFormat { @@ -36,6 +38,179 @@ default ChatTokens chatTokens() { Set getStopTokens(); + /** + * Returns {@code true} when this chat format supports tool calling. + * Formats that implement tool-calling methods must override this to return {@code true}. + * Callers should check this before passing tool specifications to avoid hitting the + * default {@link UnsupportedOperationException} deep inside a format method. + */ + default boolean supportsToolCalling() { + return false; + } + + /** + * Returns plain text to append to the system message content when tools are available. + * Used by formats that inject tool definitions into the system message. + * + *

Formats that inject tools into the user message instead should override + * {@link #injectsToolsInUserMessage()}, {@link #toolSystemMessagePrefix()}, and + * {@link #toolFirstUserMessagePrefix(String)} rather than this method. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolSystemPromptSuffix(String toolsJson) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Returns {@code true} when this format injects tool definitions into the + * first user message instead of the system message. + * + *

When this returns {@code true}, callers should: + *

    + *
  1. Prepend {@link #toolSystemMessagePrefix()} to the system message content.
  2. + *
  3. Prepend {@link #toolFirstUserMessagePrefix(String)} to the first user message.
  4. + *
+ * When {@code false} (default), callers should append {@link #toolSystemPromptSuffix} to + * the system message as before. + */ + default boolean injectsToolsInUserMessage() { + return false; + } + + /** + * Returns text to prepend to the system message content when tools are active + * and {@link #injectsToolsInUserMessage()} is {@code true}. + * Default: empty string (no prefix). + */ + default String toolSystemMessagePrefix() { + return ""; + } + + /** + * Returns the preamble to prepend to the first user message when + * {@link #injectsToolsInUserMessage()} is {@code true}. + * The preamble should include the tool definitions and usage instructions. + * + * @param toolsJson JSON array of tool definitions + */ + default String toolFirstUserMessagePrefix(String toolsJson) { + return ""; + } + + /** + * Re-encodes a prior assistant tool-call turn into the conversation token stream. + * Used when replaying multi-turn history that contains a previous tool call. + * + * @param toolCall the tool call to encode (name + raw arguments JSON) + */ + default List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Re-encodes a prior assistant turn that contained one or more tool calls as a + * single assistant message. Implementations must emit all calls inside one + * header/footer pair so the model does not see spurious assistant turn boundaries. + * + *

The default delegates to {@link #encodeToolCallAssistantTurn(ToolCallExtract)} + * for single-element lists and naively concatenates individual encodings for larger + * lists — formats that support batch tool calls should override this method. + * + * @param toolCalls the ordered list of tool calls from a single assistant turn + */ + default List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(); + for (ToolCallExtract tc : toolCalls) { + tokens.addAll(encodeToolCallAssistantTurn(tc)); + } + return tokens; + } + + /** + * Encodes a tool execution result message in the model-native format. + * + * @param toolCallId the ID of the originating tool call (may be ignored by some formats) + * @param toolName the name of the tool that was called + * @param result the result content string + */ + default List encodeToolResultTurn(String toolCallId, String toolName, String result) { + throw new UnsupportedOperationException("Tool calling not supported for: " + getClass().getSimpleName()); + } + + /** + * Detects and extracts a tool call from fully decoded model response text. + * Returns {@link Optional#empty()} when the response is a plain text answer. + * + * @param responseText the fully decoded response from the model + */ + default Optional extractToolCall(String responseText) { + return Optional.empty(); + } + + /** + * Extracts ALL tool calls from a response. Models may emit multiple + * {@code } blocks in a single turn (batch tool calls). + * The default delegates to {@link #extractToolCall} for formats that + * do not support batch calls. + * + * @param responseText the fully decoded response from the model + */ + default List extractAllToolCalls(String responseText) { + return extractToolCall(responseText).map(List::of).orElse(List.of()); + } + + /** + * Returns the recommended default temperature for this chat format. + * Used when the caller has not explicitly configured a temperature. + */ + default double defaultTemperature() { + return 0.7; + } + + /** + * Returns the recommended default top-p for this chat format. + * Used when the caller has not explicitly configured a top-p value. + */ + default double defaultTopP() { + return 0.9; + } + + /** + * Stop tokens to use when tool calling is enabled. + * Some models (LLaMA 3.1+) use a different end-of-turn token ({@code <|eom_id|>}) + * when emitting a tool call instead of a regular response. + */ + default Set getToolAwareStopTokens() { + return getStopTokens(); + } + + /** + * Returns {@code true} when this chat format has a controllable thinking/reasoning mode that + * {@link #encodeThinkingControl(boolean)} can toggle (e.g. Qwen3). Formats that return + * {@code false} (the default) have no reasoning phase to switch on or off, so the + * {@code enableThinking} flag is inert for them. Pure reasoning models that always think and + * offer no off-switch (e.g. DeepSeek-R1) also return {@code false}. + */ + default boolean supportsThinking() { + return false; + } + + /** + * Returns the tokens to append immediately after the assistant header in order to control + * the model's thinking/reasoning phase. Models that do not {@link #supportsThinking()} return + * an empty list (the default), so callers can invoke this unconditionally. + * + * @param enableThinking when {@code false}, returns the model-native primer that suppresses + * reasoning (e.g. Qwen3's pre-closed {@code } block); when {@code true}, + * returns an empty list so the model decides for itself. + */ + default List encodeThinkingControl(boolean enableThinking) { + return List.of(); + } + record ChatTokens(String tStartHeader, String tEndHeader, String tEndOfTurn, String tEndOfText, String tEndOfTextFim) { } diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index c98a72c9..f23e3c26 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -6,6 +6,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public class LlamaChatFormat implements ChatFormat { @@ -17,6 +18,7 @@ public class LlamaChatFormat implements ChatFormat { protected final int endOfTurn; protected final int endOfText; protected final int endOfMessage; + protected final int pythonTag; protected final Set stopTokens; public LlamaChatFormat(Tokenizer tokenizer) { @@ -28,6 +30,7 @@ public LlamaChatFormat(Tokenizer tokenizer) { this.endOfTurn = specialTokens.get("<|eot_id|>"); this.endOfText = specialTokens.get("<|end_of_text|>"); this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 + this.pythonTag = specialTokens.getOrDefault("<|python_tag|>", -1); // only in 3.1 this.stopTokens = Set.of(endOfText, endOfTurn); } @@ -71,4 +74,155 @@ public List encodeDialogPrompt(boolean appendAssistantTurn, Listfirst user message + * (the GGUF-embedded chat template has {@code tools_in_user_message = true} by default). + * The system message receives only an environment prefix; the tools and usage instructions + * go in the user turn. + */ + @Override + public boolean injectsToolsInUserMessage() { + return true; + } + + /** + * System-message prefix that signals tool availability to Llama 3.2. + * Matches the template's {@code "Environment: ipython\n"} line. + */ + @Override + public String toolSystemMessagePrefix() { + return "Environment: ipython\n\n"; + } + + /** + * Prepends tool definitions and usage instructions to the first user message, + * matching the Llama 3.2 GGUF chat template ({@code tools_in_user_message = true}). + * + *

Format mirrors: + *

+     * Given the following functions, please respond with a JSON for a function call
+     * with its proper arguments that best answers the given prompt.
+     *
+     * Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.
+     * Do not use variables.
+     *
+     * {toolsJson}
+     *
+     * 
+ */ + @Override + public String toolFirstUserMessagePrefix(String toolsJson) { + return "Given the following functions, please respond with a JSON for a function call " + + "with its proper arguments that best answers the given prompt.\n\n" + + "Respond in the format {\"name\": function name, \"parameters\": dictionary of " + + "argument name and its value}. Do not use variables.\n\n" + + toolsJson + "\n\n"; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history using the + * Llama 3.2 native JSON format: {@code {"name":"…","parameters":{…}}<|eot_id|>}. + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + // Preserve the <|python_tag|> prefix used by LLaMA 3.1/3.2 for tool calls so that + // replayed history looks identical to what the model originally generated. + if (pythonTag != -1) { + tokens.add(pythonTag); + } + String json = "{\"name\": \"" + toolCall.name() + "\", \"parameters\": " + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeAsList(json)); + // LLaMA 3.1 ends tool-call turns with <|eom_id|>; fall back to <|eot_id|> for 3.2. + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); + return tokens; + } + + /** + * Encodes a tool result using the LLaMA "ipython" role. + * Format: {@code <|start_header_id|>ipython<|end_header_id|>\nresult<|eot_id|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(startHeader); + tokens.addAll(tokenizer.encodeAsList("ipython")); + tokens.add(endHeader); + tokens.addAll(tokenizer.encodeAsList("\n")); + tokens.addAll(tokenizer.encodeAsList(result)); + tokens.add(endOfTurn); + return tokens; + } + + /** + * Encodes multiple tool calls as a single assistant turn. + * For a single call, delegates to the existing single-call method (preserving the + * {@code <|python_tag|>} prefix on LLaMA 3.1). + * For multiple calls, LLaMA 3.1 prefixes each with {@code <|python_tag|>}; + * LLaMA 3.2 (no python_tag) uses {@code } blocks. + */ + @Override + public List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(encodeHeader(new Message(Role.ASSISTANT, ""))); + for (ToolCallExtract tc : toolCalls) { + String json = "{\"name\": \"" + tc.name() + "\", \"parameters\": " + tc.argumentsJson() + "}"; + if (pythonTag != -1) { + tokens.add(pythonTag); + tokens.addAll(tokenizer.encodeAsList(json + "\n")); + } else { + tokens.addAll(tokenizer.encodeAsList("\n" + json + "\n\n")); + } + } + tokens.add(endOfMessage != -1 ? endOfMessage : endOfTurn); + return tokens; + } + + /** + * Detects a tool call in the decoded response text. + * Supports LLaMA 3.1 (native {@code <|python_tag|>} + {@code "parameters"} key), + * LLaMA 3.2 ({@code "arguments"} key, tag often absent), and a raw-JSON fallback + * for smaller models. Delegates to {@link ToolCallParserUtils#parseToolCallResponse}. + */ + @Override + public Optional extractToolCall(String responseText) { + return ToolCallParserUtils.parseToolCallResponse(responseText); + } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } + + /** + * Adds {@code <|eom_id|>} to the stop tokens when tools are enabled. + * LLaMA 3.1 ends tool-call turns with {@code <|eom_id|>} instead of {@code <|eot_id|>}. + */ + @Override + public Set getToolAwareStopTokens() { + if (endOfMessage != -1) { + return Set.of(endOfText, endOfTurn, endOfMessage); + } + return stopTokens; + } + } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java index b6d2e798..5d121c31 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java @@ -14,7 +14,6 @@ public class Qwen3ChatFormat implements ChatFormat { protected final int endHeader; protected final int endOfTurn; protected final int endOfText; - protected final int endOfMessage; protected final int endOfTextFim; protected final int imStart; // beginOfText protected final int imEnd; // endOfText @@ -28,13 +27,12 @@ public Qwen3ChatFormat(Qwen3Tokenizer tokenizer, ChatTokens chatTokens) { this.tokenizer = tokenizer; this.chatTokens = chatTokens; Map specialTokens = tokenizer.getSpecialTokens(); - this.beginOfText = specialTokens.getOrDefault("", -1); + this.beginOfText = -1; // Qwen3 has no BOS token; getBeginOfText() falls back to startHeader this.startHeader = specialTokens.getOrDefault(chatTokens.tStartHeader(), -1); this.endHeader = specialTokens.getOrDefault(chatTokens.tEndHeader(), -1); this.endOfTurn = specialTokens.getOrDefault(chatTokens.tEndOfTurn(), -1); this.endOfText = specialTokens.getOrDefault(chatTokens.tEndOfText(), -1); this.endOfTextFim = specialTokens.getOrDefault(chatTokens.tEndOfTextFim(), -1); - this.endOfMessage = specialTokens.getOrDefault("", -1); // Use default value if key not found this.imStart = startHeader; this.imEnd = endHeader; @@ -129,4 +127,153 @@ public Set getStopTokens() { return stopTokens; } + + @Override + public double defaultTemperature() { + return 0.8; + } + + @Override + public double defaultTopP() { + return 0.9; + } + + /** + * Genuine Qwen3 exposes the {@code enable_thinking} template switch and so supports thinking + * control. DeepSeek-R1 is routed through this same format (detected by the absence of an + * {@code <|im_end|>} token) but is a pure reasoning model with no off-switch, so it reports + * {@code false} and is left to always reason. + */ + @Override + public boolean supportsThinking() { + return imEnd != -1; + } + + /** + * Qwen3 thinking control. When thinking is disabled, primes a pre-closed + * {@code \n\n\n\n} block right after the assistant header so the model skips + * its reasoning phase — matching the {@code enable_thinking=false} branch of the official + * Qwen3 chat template. When enabled (or for DeepSeek-R1, which cannot disable thinking), + * returns nothing and lets the model reason on its own. + * + *

The {@code }/{@code } markers are emitted as their canonical + * single token ids (not ordinary BPE sub-pieces): the tokenizer strips them from its special + * map so reasoning renders as text, but the model only recognises the closed block — and thus + * actually skips reasoning — when it sees the real control tokens it was trained on. + */ + @Override + public List encodeThinkingControl(boolean enableThinking) { + if (enableThinking || !supportsThinking()) { + return List.of(); + } + int thinkStart = tokenizer.getThinkStartToken(); + int thinkEnd = tokenizer.getThinkEndToken(); + if (thinkStart == -1 || thinkEnd == -1) { + // GGUF without dedicated think tokens — fall back to ordinary text encoding. + return tokenizer.encodeOrdinaryAsList("\n\n\n\n"); + } + List tokens = new ArrayList<>(); + tokens.add(thinkStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n\n")); + tokens.add(thinkEnd); + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n\n")); + return tokens; + } + + // ── Tool calling ────────────────────────────────────────────────────────── + + @Override + public boolean supportsToolCalling() { + return true; + } + + /** + * Qwen3 tool calling system prompt suffix. + * Appended to the system message; instructs the model to wrap tool calls in + * {@code } XML tags. + */ + @Override + public String toolSystemPromptSuffix(String toolsJson) { + return "\n\n# Tools\n\n" + + "You may call one or more functions to assist with the user query.\n\n" + + "You are provided with function signatures within XML tags:\n" + + "\n" + + toolsJson + + "\n\n\n" + + "For each function call, return a json object with function name and arguments " + + "within XML tags:\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + ""; + } + + /** + * Re-encodes a prior assistant tool-call turn for multi-turn history. + * Format: {@code <|im_start|>assistant\n\nJSON\n<|im_end|>} + */ + @Override + public List encodeToolCallAssistantTurn(ToolCallExtract toolCall) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + String json = "{\"name\":\"" + toolCall.name() + "\",\"arguments\":" + toolCall.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Encodes multiple tool calls as a single assistant turn: one {@code <|im_start|>assistant} + * header, all {@code } blocks concatenated, then {@code <|im_end|>}. + * For a single call, delegates to the existing single-call method. + */ + @Override + public List encodeToolCallAssistantTurn(List toolCalls) { + if (toolCalls.isEmpty()) return List.of(); + if (toolCalls.size() == 1) return encodeToolCallAssistantTurn(toolCalls.get(0)); + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("assistant\n")); + for (ToolCallExtract tc : toolCalls) { + String json = "{\"name\":\"" + tc.name() + "\",\"arguments\":" + tc.argumentsJson() + "}"; + tokens.addAll(tokenizer.encodeOrdinaryAsList("\n" + json + "\n")); + } + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Encodes a tool result in the native Qwen3 format: a {@code user} turn whose content is the + * result wrapped in {@code } tags, matching the official Qwen3 + * chat template. (Qwen3 has no dedicated "tool" role — results are delivered as user turns.) + * Format: {@code <|im_start|>user\n\nresult\n<|im_end|>} + */ + @Override + public List encodeToolResultTurn(String toolCallId, String toolName, String result) { + List tokens = new ArrayList<>(); + tokens.add(imStart); + tokens.addAll(tokenizer.encodeOrdinaryAsList("user\n\n" + result + "\n")); + if (imEnd != -1) { + tokens.add(imEnd); + } + return tokens; + } + + /** + * Detects a tool call enclosed in {@code } tags. + * Delegates to {@link ToolCallParserUtils#parseToolCallResponse}. + */ + @Override + public Optional extractToolCall(String responseText) { + return ToolCallParserUtils.parseToolCallResponse(responseText); + } + + @Override + public List extractAllToolCalls(String responseText) { + return ToolCallParserUtils.parseAllToolCalls(responseText); + } } diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java new file mode 100644 index 00000000..b5f82c51 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallExtract.java @@ -0,0 +1,20 @@ +package org.beehive.gpullama3.model.format; + +import java.util.Optional; + +/** + * Represents a single tool call extracted from a model response. + * Contains the raw strings — JSON parsing of arguments is left to the caller. + * + * @param name the tool/function name to invoke + * @param argumentsJson the arguments as a JSON object string, e.g. {"location":"Boston"} + * @param id optional tool call ID parsed from the model response; callers that + * generate IDs themselves (e.g. Ollama-style "call_XXXXXXXX") may pass + * {@link Optional#empty()} and let the consumer generate one + */ +public record ToolCallExtract(String name, String argumentsJson, Optional id) { + + public ToolCallExtract(String name, String argumentsJson) { + this(name, argumentsJson, Optional.empty()); + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java new file mode 100644 index 00000000..0691b4e7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/ToolCallParserUtils.java @@ -0,0 +1,205 @@ +package org.beehive.gpullama3.model.format; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * Pure-string tool-call extraction for Llama and Qwen3 response formats. + * + * All methods are stateless and do not require any model or tokenizer instance, + * making them directly unit-testable. + */ +public final class ToolCallParserUtils { + + private ToolCallParserUtils() {} + + /** + * Extracts a single tool call from a model response text. + * + * Recognised formats (in priority order): + * 1. {@code <|python_tag|>{…}} — LLaMA 3.1 native + * 2. {@code } — LLaMA 3.2 and Qwen3 (closed or unclosed) + * 3. Raw JSON object optionally inside markdown code fences — fallback + * + * Both {@code "parameters"} and {@code "arguments"} are tried as the argument key, + * covering LLaMA 3.1/3.2 and Qwen3 variants transparently. + */ + public static Optional parseToolCallResponse(String responseText) { + // 1. Native LLaMA 3.1 format: <|python_tag|>{...} + int idx = responseText.indexOf("<|python_tag|>"); + if (idx != -1) { + String json = responseText.substring(idx + "<|python_tag|>".length()).strip(); + return parseToolCallJson(json); + } + + // 2. LLaMA 3.2 format: ... + int tcStart = responseText.indexOf(""); + int tcEnd = responseText.lastIndexOf(""); + if (tcStart != -1 && tcEnd != -1 && tcEnd > tcStart) { + String json = responseText.substring(tcStart + "".length(), tcEnd).strip(); + return parseToolCallJson(json); + } + // 2b. Unclosed — model stopped (eot_id / eom_id) before writing the closing tag + if (tcStart != -1 && tcEnd == -1) { + String json = responseText.substring(tcStart + "".length()).strip(); + return parseToolCallJson(json); + } + + // 3. Fallback: raw JSON, possibly inside markdown code fences + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + return parseToolCallJson(stripped); + } + + return Optional.empty(); + } + + /** + * Parses a tool call JSON object extracted from a {@code } block or raw JSON. + * Accepts {@code {"name":…,"parameters":{…}}}, {@code {"function":…,"parameters":{…}}}, + * and {@code {"name":…,"arguments":{…}}} — covering both LLaMA and Qwen3 variants. + */ + private static Optional parseToolCallJson(String json) { + String name = extractStringValue(json, "name"); + if (name == null) { + name = extractStringValue(json, "function"); + } + if (name == null) return Optional.empty(); + + String argsJson = extractNestedObject(json, "parameters"); + if (argsJson == null) argsJson = extractNestedObject(json, "arguments"); + if (argsJson == null) argsJson = "{}"; + + return Optional.of(new ToolCallExtract(name, argsJson)); + } + + // Batch extraction + + /** + * Extracts ALL tool calls from a response that may contain multiple + * {@code } blocks (Llama 3.2 and Qwen3 batch calls). + * + * Falls back to the raw-JSON single-call path if no tags are found. + * Returns an empty list when the response contains no tool calls. + */ + public static List parseAllToolCalls(String responseText) { + List calls = new ArrayList<>(); + + // <|python_tag|> (Llama 3.1) — single call by definition + int pythonIdx = responseText.indexOf("<|python_tag|>"); + if (pythonIdx != -1) { + parseToolCallJson(responseText.substring(pythonIdx + "<|python_tag|>".length()).strip()) + .ifPresent(calls::add); + return calls; + } + + // Scan for all blocks + int searchFrom = 0; + while (true) { + int start = responseText.indexOf("", searchFrom); + if (start == -1) break; + int end = responseText.indexOf("", start); + String json; + if (end != -1) { + json = responseText.substring(start + "".length(), end).strip(); + searchFrom = end + "".length(); + } else { + // Unclosed tag — model stopped before writing the closing tag + json = responseText.substring(start + "".length()).strip(); + searchFrom = responseText.length(); + } + parseToolCallJson(json).ifPresent(calls::add); + if (end == -1) break; + } + + // Raw JSON fallback (no tags at all) + if (calls.isEmpty()) { + String stripped = stripMarkdownFences(responseText.strip()); + if (stripped.startsWith("{")) { + parseToolCallJson(stripped).ifPresent(calls::add); + } + } + + return calls; + } + + // Shared helpers + + /** Strips surrounding markdown code fences (```…```) if present. */ + public static String stripMarkdownFences(String text) { + if (!text.startsWith("```")) return text; + int firstNewline = text.indexOf('\n'); + if (firstNewline == -1) return text; + String body = text.substring(firstNewline + 1); + if (body.endsWith("```")) body = body.substring(0, body.length() - 3).stripTrailing(); + return body.strip(); + } + + /** + * Extracts the string value for {@code "key": ""} from a JSON object. + * Tolerates whitespace around {@code :} and correctly skips escaped quotes ({@code \"}) + * inside the value, so multi-line code strings with embedded {@code "} are returned intact. + */ + private static String extractStringValue(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int quoteStart = json.indexOf('"', colonIdx + 1); + if (quoteStart == -1) return null; + // Scan for the closing quote, honouring backslash escapes + int i = quoteStart + 1; + while (i < json.length()) { + char c = json.charAt(i); + if (c == '\\') { + i += 2; // skip escape sequence (e.g. \", \\, \n) + } else if (c == '"') { + break; + } else { + i++; + } + } + if (i >= json.length()) return null; + return json.substring(quoteStart + 1, i); + } + + /** + * Extracts the JSON object value for {@code "key": {…}} using brace-counting. + * Handles nested objects and tolerates whitespace around {@code :}. + * + *

Brace counting is string-aware: {@code {} and } characters appearing inside + * JSON string literals (e.g. a {@code "code"} argument containing Java source) do not affect + * the depth counter, and {@code \"} escapes inside strings are skipped. This keeps argument + * objects whose string values contain braces intact. + */ + private static String extractNestedObject(String json, String key) { + String marker = "\"" + key + "\""; + int markerIdx = json.indexOf(marker); + if (markerIdx == -1) return null; + int colonIdx = json.indexOf(':', markerIdx + marker.length()); + if (colonIdx == -1) return null; + int braceStart = json.indexOf('{', colonIdx + 1); + if (braceStart == -1) return null; + int depth = 0; + boolean inString = false; + for (int i = braceStart; i < json.length(); i++) { + char c = json.charAt(i); + if (inString) { + if (c == '\\') { + i++; // skip the escaped character (e.g. \", \\, \n) + } else if (c == '"') { + inString = false; + } + } else if (c == '"') { + inString = true; + } else if (c == '{') { + depth++; + } else if (c == '}') { + if (--depth == 0) return json.substring(braceStart, i + 1); + } + } + return null; // unbalanced + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java index d178be7c..5904316f 100644 --- a/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java +++ b/src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java @@ -2,6 +2,8 @@ import org.beehive.gpullama3.inference.InferenceCore; import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.InferenceEngineWithBatchPrefillDecode; +import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode; import org.beehive.gpullama3.inference.sampler.Sampler; import org.beehive.gpullama3.inference.state.Qwen3State; import org.beehive.gpullama3.inference.state.State; @@ -88,10 +90,10 @@ public List generateTokens(State state, int startPosition, List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) { - throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen3"); + return InferenceEngineWithBatchPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } if (WITH_PREFILL_DECODE) { - throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen3"); + return InferenceEngineWithPrefillDecode.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); } diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java index 077dd536..ea122d44 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Qwen3Tokenizer.java @@ -25,6 +25,11 @@ public class Qwen3Tokenizer implements Tokenizer { private final Map, Integer> merges; private final Map specialTokens; private final int[] tokenTypes; + /** Canonical {@code }/{@code } token ids (or -1), captured before they are + * removed from {@link #specialTokens} so reasoning renders as text. Used to prime the + * thinking-disabled control block with the tokens the model was actually trained on. */ + private final int thinkStartToken; + private final int thinkEndToken; /** buffer to store incomplete UTF-8 sequence */ private final byte[] bufUtf8 = new byte[4]; /** index in UTF-8 buffer */ @@ -59,6 +64,10 @@ public Qwen3Tokenizer(Map metadata, Vocabulary vocabulary, boole i -> specialTokensList.get(i), i -> baseTokens + i) ); + // Capture the canonical think-control token ids BEFORE removing them from the special + // map (they are removed so reasoning text renders verbatim during decode). + this.thinkStartToken = specialTokens.getOrDefault("", -1); + this.thinkEndToken = specialTokens.getOrDefault("", -1); specialTokens.remove(""); specialTokens.remove(""); @@ -145,6 +154,16 @@ public Map getSpecialTokens() { return specialTokens; } + /** Canonical {@code } token id, or {@code -1} if this GGUF has no think tokens. */ + public int getThinkStartToken() { + return thinkStartToken; + } + + /** Canonical {@code } token id, or {@code -1} if this GGUF has no think tokens. */ + public int getThinkEndToken() { + return thinkEndToken; + } + @Override public boolean isSpecialToken(int tokenIndex) { return specialTokens.containsValue(tokenIndex); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java index ffca1e4a..26b2f2b0 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Qwen3Kernels.java @@ -894,5 +894,325 @@ public static void fusedQKRmsNorm( } } } + + // ── Batch prefill kernels ──────────────────────────────────────────────── + + /** + * Batched fused Q/K/V projection for Qwen3 GQA (FP16 weights, FP16 input). + * + *

Like {@code TransformerBatchPrefillKernels.batchedFusedQKVMatmul} but uses + * separate {@code qDim} (Q output rows) and {@code kvDim} (K/V output rows). + * Row layout: [0, qDim) → Q; [qDim, qDim+kvDim) → K; [qDim+kvDim, qDim+2*kvDim) → V. + * Q output stride per batch = qDim; K/V output stride = kvDim.

+ * + * Worker: B*(qDim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmulFP16( + KernelContext context, + HalfFloatArray xbFP16Batch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + HalfFloatArray wq, + HalfFloatArray wk, + HalfFloatArray wv, + int inputDim, + int qDim, + int kvDim, + int localWorkGroupSize) { + + int groupId = context.globalIdx / localWorkGroupSize; + int localId = context.localIdx; + int totalRows = qDim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * inputDim; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < qDim) { + int rowOff = rowIdx * inputDim; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partial += wq.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * qDim + rowIdx, localSum[0]); + + } else if (rowIdx < qDim + kvDim) { + int kRow = rowIdx - qDim; + int rowOff = kRow * inputDim; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partial += wk.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - qDim - kvDim; + int rowOff = vRow * inputDim; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + partial += wv.get(rowOff + j).getFloat32() * xbFP16Batch.get(inputOff + j).getFloat32(); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + /** + * Batched fused Q/K/V projection for Qwen3 GQA (Q8_0 weights, FP32 input). + * + * Worker: B*(qDim+2*kvDim) workgroups × localWorkGroupSize threads. + */ + public static void batchedFusedQKVMatmulQ8_0( + KernelContext context, + FloatArray wrapXbBatch, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + ByteArray wq, + ByteArray wk, + ByteArray wv, + int inputDim, + int qDim, + int kvDim, + int localWorkGroupSize) { + + int groupId = context.globalIdx / localWorkGroupSize; + int localId = context.localIdx; + int totalRows = qDim + 2 * kvDim; + int batchIdx = groupId / totalRows; + int rowIdx = groupId % totalRows; + int inputOff = batchIdx * inputDim; + + final int blockSize = 32; + final int Q8_0_BLOCK_BYTES = 34; + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + + if (rowIdx < qDim) { + int blocksPerRow = (inputDim + blockSize - 1) / blockSize; + int rowBlockOff = rowIdx * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlock = j % blockSize; + int blockByteOff = (rowBlockOff + blockIdx) * Q8_0_BLOCK_BYTES; + HalfFloat sc = wq.getHalfFloat(blockByteOff); + byte q8 = wq.get(blockByteOff + 2 + withinBlock); + partial += ((float) q8 * sc.getFloat32()) * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapQBatch.set(batchIdx * qDim + rowIdx, localSum[0]); + + } else if (rowIdx < qDim + kvDim) { + int kRow = rowIdx - qDim; + int blocksPerRow = (inputDim + blockSize - 1) / blockSize; + int rowBlockOff = kRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlock = j % blockSize; + int blockByteOff = (rowBlockOff + blockIdx) * Q8_0_BLOCK_BYTES; + HalfFloat sc = wk.getHalfFloat(blockByteOff); + byte q8 = wk.get(blockByteOff + 2 + withinBlock); + partial += ((float) q8 * sc.getFloat32()) * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapKBatch.set(batchIdx * kvDim + kRow, localSum[0]); + + } else { + int vRow = rowIdx - qDim - kvDim; + int blocksPerRow = (inputDim + blockSize - 1) / blockSize; + int rowBlockOff = vRow * blocksPerRow; + float partial = 0.0f; + for (int j = localId; j < inputDim; j += localWorkGroupSize) { + int blockIdx = j / blockSize; + int withinBlock = j % blockSize; + int blockByteOff = (rowBlockOff + blockIdx) * Q8_0_BLOCK_BYTES; + HalfFloat sc = wv.getHalfFloat(blockByteOff); + byte q8 = wv.get(blockByteOff + 2 + withinBlock); + partial += ((float) q8 * sc.getFloat32()) * wrapXbBatch.get(inputOff + j); + } + localSum[localId] = partial; + context.localBarrier(); + for (int s = localWorkGroupSize / 2; s > 0; s >>= 1) { + if (localId < s) localSum[localId] += localSum[localId + s]; + context.localBarrier(); + } + if (localId == 0) wrapVBatch.set(batchIdx * kvDim + vRow, localSum[0]); + } + } + + /** + * Batched fused Q/K RMSNorm for Qwen3 GQA. + * + *

Workgroup layout: B*(nHeads+nHeadKv) groups × nEmbdHead local threads. + * Groups [0, B*nHeads) normalize Q; groups [B*nHeads, B*(nHeads+nHeadKv)) normalize K. + * groupIdx = batchIdx*(nHeads+nHeadKv) + headSlot.

+ * + * Worker: B*(nHeads+nHeadKv) workgroups × nEmbdHead threads. + */ + public static void batchedFusedQKRmsNorm( + KernelContext context, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray qWeights, + FloatArray kWeights, + int nHeads, + int nHeadKv, + int nEmbdHead, + int qDim, + int kvDim, + float rmsNormEps) { + + int groupId = context.globalIdx / nEmbdHead; + int localId = context.localIdx; + int localSize = context.localGroupSizeX; + int totalHeadsPerBatch = nHeads + nHeadKv; + + int batchIdx = groupId / totalHeadsPerBatch; + int headSlot = groupId % totalHeadsPerBatch; + + float[] localSum = context.allocateFloatLocalArray(nEmbdHead); + + if (headSlot < nHeads) { + // Q head + int headOffset = batchIdx * qDim + headSlot * nEmbdHead; + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = wrapQBatch.get(headOffset + i); + partialSum += val * val; + } + localSum[localId] = partialSum; + context.localBarrier(); + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) localSum[localId] += localSum[localId + stride]; + context.localBarrier(); + } + float ss = localSum[0] / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + context.localBarrier(); + for (int i = localId; i < nEmbdHead; i += localSize) { + wrapQBatch.set(headOffset + i, qWeights.get(i) * ss * wrapQBatch.get(headOffset + i)); + } + } else { + // K head + int kHeadIdx = headSlot - nHeads; + int headOffset = batchIdx * kvDim + kHeadIdx * nEmbdHead; + float partialSum = 0.0f; + for (int i = localId; i < nEmbdHead; i += localSize) { + float val = wrapKBatch.get(headOffset + i); + partialSum += val * val; + } + localSum[localId] = partialSum; + context.localBarrier(); + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) localSum[localId] += localSum[localId + stride]; + context.localBarrier(); + } + float ss = localSum[0] / nEmbdHead + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + context.localBarrier(); + for (int i = localId; i < nEmbdHead; i += localSize) { + wrapKBatch.set(headOffset + i, kWeights.get(i) * ss * wrapKBatch.get(headOffset + i)); + } + } + } + + /** + * Batched fused RoPE rotation + KV cache write for Qwen3. + * + *

Like {@code TransformerBatchPrefillKernels.batchedRopeWithKVCache} but uses + * Qwen3 RoPE theta (1 000 000) and a separate {@code qDim} for the Q stride.

+ * + *

globalIdx = batchIdx*(qDim/2) + pairIdx. + * K rotation is applied only when pairIdx < kvDim/2.

+ * + * Worker: B*(qDim/2) global threads, localSize tuned like Llama RoPE. + */ + public static void batchedRopeWithKVCacheQwen3( + KernelContext context, + IntArray batchStartPosHolder, + FloatArray wrapQBatch, + FloatArray wrapKBatch, + FloatArray wrapVBatch, + FloatArray wrapKeyCache, + FloatArray wrapValueCache, + int kvDim, + int nEmbdHead, + int layerIndex, + int contextLength, + int qDim) { + + int globalIdx = context.globalIdx; + int halfQDim = qDim / 2; + int batchIdx = globalIdx / halfQDim; + int pairIdx = globalIdx % halfQDim; + + int pos = batchStartPosHolder.get(0) + batchIdx; + + // Qwen3 uses split-half RoPE: pair element ic with ic + nEmbdHead/2 within each head. + int halfEmbdHead = nEmbdHead / 2; + int ic = pairIdx % halfEmbdHead; + int headIdx = pairIdx / halfEmbdHead; + + float freq = 1.0f / TornadoMath.pow(1000000.0f, 2.0f * ic / (float) nEmbdHead); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q (split-half pairs within each head) + int qHeadBase = batchIdx * qDim + headIdx * nEmbdHead; + float v0q = wrapQBatch.get(qHeadBase + ic); + float v1q = wrapQBatch.get(qHeadBase + ic + halfEmbdHead); + wrapQBatch.set(qHeadBase + ic, v0q * fcr - v1q * fci); + wrapQBatch.set(qHeadBase + ic + halfEmbdHead, v0q * fci + v1q * fcr); + + // Rotate K and write K,V to cache (only for KV pairs) + if (pairIdx < kvDim / 2) { + int kHeadIdx = pairIdx / halfEmbdHead; + int kHeadBase = batchIdx * kvDim + kHeadIdx * nEmbdHead; + float v0k = wrapKBatch.get(kHeadBase + ic); + float v1k = wrapKBatch.get(kHeadBase + ic + halfEmbdHead); + float rotK0 = v0k * fcr - v1k * fci; + float rotK1 = v0k * fci + v1k * fcr; + wrapKBatch.set(kHeadBase + ic, rotK0); + wrapKBatch.set(kHeadBase + ic + halfEmbdHead, rotK1); + + int cacheOff = layerIndex * contextLength * kvDim + pos * kvDim + kHeadIdx * nEmbdHead; + wrapKeyCache.set(cacheOff + ic, rotK0); + wrapKeyCache.set(cacheOff + ic + halfEmbdHead, rotK1); + wrapValueCache.set(cacheOff + ic, wrapVBatch.get(kHeadBase + ic)); + wrapValueCache.set(cacheOff + ic + halfEmbdHead, wrapVBatch.get(kHeadBase + ic + halfEmbdHead)); + } + } } // @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 0e8b21b0..b2646351 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -24,7 +24,7 @@ public class Qwen3FP16FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen3-specific state - private final Qwen3State qwen3State; + protected final Qwen3State qwen3State; // Qwen3-specific GQA parameters private final int nHeadKv; private final int nEmbdHeadK; @@ -192,7 +192,12 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var unifiedLayer = new TaskGraph(taskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(state.wrapX); + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, state.wrapX); + } else { + unifiedLayer.consumeFromDevice(state.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS norm weights @@ -356,12 +361,27 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { config.hiddenDim(), // input dim config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC) - .persistOnDevice(qwen3State.wrapX); + .persistOnDevice(qwen3State.wrapX, qwen3State.wrapKeyCache, qwen3State.wrapValueCache); return unifiedLayer; } // @formatter:on + /** + * Returns the explicit predecessor graph name for consumeFromDevice. + * + *

The single-token plan receives {@code wrapX} (and relays all persisted buffers, + * including the KV cache) from a named predecessor graph: the activation graph for + * layer 0, the previous layer graph otherwise. The no-arg consume form looks up the + * current graph's name as the source key, which never matches in interpreter + * mode, so the persisted KV-cache buffer is not propagated and gets re-allocated every + * decode token — exhausting the device-memory pool (OOM) on long generations. + * Decode subclasses override this with their own predecessor names.

+ */ + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "activationUpdate" : "layer_" + (layerIndex - 1); + } + /** * Configure data transfers for first and subsequent layers */ @@ -377,8 +397,12 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapKeyCache, qwen3State.wrapValueCache, // qwen3State.wrapAtt, qwen3State.wrapHb ); } else { - // Subsequent layers: Consume data from previous layer - unifiedLayer.consumeFromDevice(context, qwen3State.wrapXb, qwen3State.wrapXb2, // + // Subsequent layers: consume from the previous layer graph BY NAME. + // The no-arg consumeFromDevice form uses the current graph's own name as the + // source key, which never matches the predecessor in interpreter mode, so the + // persisted KV cache is not propagated and is re-allocated every token (OOM). + String pred = "layer_" + (layerIndex - 1); + unifiedLayer.consumeFromDevice(pred, context, qwen3State.wrapXb, qwen3State.wrapXb2, // qwen3State.wrapQ, qwen3State.wrapK, // qwen3State.wrapV, qwen3State.wrapKeyCache, // qwen3State.wrapValueCache, qwen3State.wrapAtt, // diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersDecode.java new file mode 100644 index 00000000..83c4aefc --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersDecode.java @@ -0,0 +1,59 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode transformer-layer TaskGraphs for the unified batched prefill-decode plan (Qwen3 FP16). + * + *

Layer 0: KV cache is consumed from "decodeActivation" (already allocated by the batch + * prefill phase). Working buffers get FIRST_EXECUTION allocation. Layers 1+: all consumed + * objects use the explicit predecessor name to satisfy TornadoVM interpreter mode.

+ * + *

Qwen3FP16FFNLayers does not use wrapXbFP16 in any task, so it is excluded.

+ */ +public class Qwen3FP16FFNLayersDecode extends Qwen3FP16FFNLayers { + + public Qwen3FP16FFNLayersDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapAtt, qwen3State.wrapHb); + // KV cache already allocated by batch prefill; relay from decode activation graph. + layer.consumeFromDevice("decodeActivation", + qwen3State.wrapKeyCache, qwen3State.wrapValueCache); + } else { + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder, + qwen3State.temp, qwen3State.tempFFN); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersPrefillDecode.java new file mode 100644 index 00000000..bba3d49e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/decode/Qwen3FP16FFNLayersPrefillDecode.java @@ -0,0 +1,48 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Decode transformer-layer TaskGraphs for the single-token prefill/decode plan (Qwen3 FP16). + * + *

Layer 0 delegates to the base-class {@link Qwen3FP16FFNLayers#configureLayerDataTransfers} + * which allocates wrapKeyCache/wrapValueCache with FIRST_EXECUTION. Layers 1+ consume all + * live buffers from the explicit predecessor graph to satisfy TornadoVM interpreter mode.

+ * + *

Note: Qwen3FP16FFNLayers does not use wrapXbFP16 in any kernel task, so it is + * intentionally excluded from the consume list.

+ */ +public class Qwen3FP16FFNLayersPrefillDecode extends Qwen3FP16FFNLayers { + + public Qwen3FP16FFNLayersPrefillDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder); + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/Qwen3FP16LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/Qwen3FP16LayersBatchPrefill.java new file mode 100644 index 00000000..85b116d2 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/prefill/Qwen3FP16LayersBatchPrefill.java @@ -0,0 +1,253 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Batched-prefill transformer-layer TaskGraphs for the Qwen3 FP16 unified batched prefill-decode plan. + * + *

Mirrors {@link LlamaFP16LayersBatchPrefill} but adapts to Qwen3's GQA layout and + * Qwen3-specific kernels (fused Q/K RMSNorm, RoPE theta = 1 000 000). Avoids any calls to + * {@code Qwen3Configuration.headSize()}, {@code kvDim()}, or {@code kvMul()} which throw.

+ */ +public class Qwen3FP16LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs { + + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final int nHeadKv; + private final int nEmbdHeadK; + private final int nEmbdHeadV; + private final int nEmbdHead; + private final int qDim; + private final int kvDim; + private final int gqa; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public Qwen3FP16LayersBatchPrefill(Qwen3State state, Qwen3TornadoWeights weights, + Qwen3Configuration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.nHeadKv = config.numberOfKeyValueHeads(); + this.nEmbdHeadK = config.numberOfHeadsKey(); + this.nEmbdHeadV = config.numberOfHeadsValue(); + this.nEmbdHead = nEmbdHeadV; + this.qDim = nEmbdHeadK * config.numberOfHeads(); + this.kvDim = nEmbdHeadV * nHeadKv; + this.gqa = config.numberOfHeads() / nHeadKv; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchPrefillLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { + String graphName = "batchPrefillLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph layer = new TaskGraph(graphName); + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + layer.consumeFromDevice("prefillActivation", state.wrapXBatch); + } else { + String pred = "batchPrefillLayer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXBatch, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapXbBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights: upload once + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + + // ── Attention Block ──────────────────────────────────────────────────── + layer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP16, + context, state.wrapXbFP16Batch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + layer.task("batch_qkv", + Qwen3Kernels::batchedFusedQKVMatmulFP16, + context, + state.wrapXbFP16Batch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + dim, qDim, kvDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_qk_rmsnorm", + Qwen3Kernels::batchedFusedQKRmsNorm, + context, + state.wrapQBatch, state.wrapKBatch, + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + config.numberOfHeads(), nHeadKv, nEmbdHead, + qDim, kvDim, config.rmsNormEps()); + + layer.task("batch_rope_kv", + Qwen3Kernels::batchedRopeWithKVCacheQwen3, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, nEmbdHead, layerIndex, config.contextLength(), qDim); + + // Reuses batchedFlashAttention: passes qDim as the 'dim' stride parameter. + // Valid because qDim == dim for all standard Qwen3 models (nEmbdHeadK = dim/nHeads). + layer.task("batch_attention", + TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), nEmbdHead, + kvDim, gqa, layerIndex, config.contextLength(), qDim); + + // Output projection: n=qDim (input), d=dim (output) + layer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asHalfFloatArray(), + qDim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + layer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUp, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidual, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asHalfFloatArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return layer; + } + // @formatter:on + + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + + int qkvRows = qDim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker( + batchSize * (config.numberOfHeads() + nHeadKv) * nEmbdHead, nEmbdHead); + + int ropeGlobal = batchSize * (qDim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + + int optLocal = findOptimalLocalSize(nEmbdHead); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * config.numberOfHeads() * optLocal, optLocal); + + // Wo: B*dim output rows (n=qDim, d=dim) + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchPrefillLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_qk_rmsnorm", qkRmsNormWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index 47584da6..5e3279a1 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -28,7 +28,7 @@ public class Qwen3Q8_0FFNLayers extends AbstractTransformerLayerTaskGraphs { // Typed reference to Qwen3-specific state - private final Qwen3State qwen3State; + protected final Qwen3State qwen3State; // Qwen3-specific GQA parameters private final int nHeadKv; @@ -110,7 +110,12 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { var unifiedLayer = new TaskGraph(taskGraphName); // === Data Setup === - unifiedLayer.consumeFromDevice(qwen3State.wrapX); + String wrapXSrc = predecessorGraphName(layerIndex); + if (wrapXSrc != null) { + unifiedLayer.consumeFromDevice(wrapXSrc, qwen3State.wrapX); + } else { + unifiedLayer.consumeFromDevice(qwen3State.wrapX); + } unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // Attention weights weights.rms_att_weightLayered[layerIndex].asFloatArray(), // RMS norm weights @@ -275,12 +280,25 @@ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { config.dim(), // output dim LOCAL_WORK_GROUP_SIZE_ALLOC); - unifiedLayer.persistOnDevice(state.wrapX); + unifiedLayer.persistOnDevice(state.wrapX, state.wrapKeyCache, state.wrapValueCache); return unifiedLayer; } // @formatter:on + /** + * Returns the explicit predecessor graph name for consumeFromDevice. + * + *

The single-token plan relays all persisted buffers (including the KV cache) from a + * named predecessor graph: the activation graph for layer 0, the previous layer graph + * otherwise. The no-arg consume form does not propagate the persisted KV cache in + * interpreter mode, so it is re-allocated every decode token and exhausts the device + * memory pool (OOM) on long generations. Decode subclasses override this.

+ */ + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "activationUpdate" : "layer_" + (layerIndex - 1); + } + /** * Configure data transfers for first and subsequent layers */ @@ -304,8 +322,12 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapAtt, qwen3State.wrapHb); } else { - // Subsequent layers: Consume data from previous layer - unifiedLayer.consumeFromDevice(context, + // Subsequent layers: consume from the previous layer graph BY NAME. The no-arg + // consume form does not propagate the persisted KV cache in interpreter mode, so + // it would be re-allocated every decode token and exhaust the memory pool (OOM). + String pred = "layer_" + (layerIndex - 1); + unifiedLayer.consumeFromDevice(pred, + context, qwen3State.wrapXb, qwen3State.wrapXb2, qwen3State.wrapQ, @@ -317,8 +339,7 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye qwen3State.wrapHb, qwen3State.positionHolder); - Qwen3State qwen3State = (Qwen3State) state; - unifiedLayer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); // + unifiedLayer.consumeFromDevice(pred, qwen3State.tempQcur, qwen3State.tempKcur); // } return unifiedLayer; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersDecode.java new file mode 100644 index 00000000..bf77c38d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersDecode.java @@ -0,0 +1,59 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Decode transformer-layer TaskGraphs for the unified batched prefill-decode plan (Qwen3 Q8_0). + * + *

Layer 0: KV cache consumed from "decodeActivation" (already allocated by batch prefill). + * Layers 1+: all consumed objects use explicit predecessor name for interpreter mode.

+ */ +public class Qwen3Q8_0FFNLayersDecode extends Qwen3Q8_0FFNLayers { + + public Qwen3Q8_0FFNLayersDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.positionHolder, qwen3State.temp, qwen3State.tempFFN); + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + qwen3State.tempQcur, qwen3State.tempKcur); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapAtt, qwen3State.wrapHb); + // KV cache already allocated by batch prefill; relay from decode activation graph. + layer.consumeFromDevice("decodeActivation", + qwen3State.wrapKeyCache, qwen3State.wrapValueCache); + } else { + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder, + qwen3State.temp, qwen3State.tempFFN); + layer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); + } + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersPrefillDecode.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersPrefillDecode.java new file mode 100644 index 00000000..cfa9f0b8 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/decode/Qwen3Q8_0FFNLayersPrefillDecode.java @@ -0,0 +1,45 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; +import uk.ac.manchester.tornado.api.TaskGraph; + +/** + * Decode transformer-layer TaskGraphs for the single-token prefill/decode plan (Qwen3 Q8_0). + * + *

Layer 0 delegates to the base-class which allocates wrapKeyCache/wrapValueCache with + * FIRST_EXECUTION. Layers 1+ consume all live buffers from the explicit predecessor graph.

+ */ +public class Qwen3Q8_0FFNLayersPrefillDecode extends Qwen3Q8_0FFNLayers { + + public Qwen3Q8_0FFNLayersPrefillDecode(String taskGraph, Qwen3State state, + Qwen3TornadoWeights weights, Qwen3Configuration config, + SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + } + + @Override + protected String predecessorGraphName(int layerIndex) { + return (layerIndex == 0) ? "decodeActivation" : "layer_" + (layerIndex - 1); + } + + @Override + protected TaskGraph configureLayerDataTransfers(TaskGraph layer, int layerIndex) { + if (layerIndex == 0) { + return super.configureLayerDataTransfers(layer, 0); + } + String pred = "layer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + qwen3State.wrapXb, qwen3State.wrapXb2, + qwen3State.wrapQ, qwen3State.wrapK, qwen3State.wrapV, + qwen3State.wrapKeyCache, qwen3State.wrapValueCache, + qwen3State.wrapAtt, qwen3State.wrapHb, + qwen3State.positionHolder); + layer.consumeFromDevice(qwen3State.tempQcur, qwen3State.tempKcur); + return layer; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/Qwen3Q8_0LayersBatchPrefill.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/Qwen3Q8_0LayersBatchPrefill.java new file mode 100644 index 00000000..b3db3f41 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/prefill/Qwen3Q8_0LayersBatchPrefill.java @@ -0,0 +1,250 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill; + +import org.beehive.gpullama3.inference.state.Qwen3State; +import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights; +import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Qwen3Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; +import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +/** + * Batched-prefill transformer-layer TaskGraphs for the Qwen3 Q8_0 unified batched prefill-decode plan. + * + *

Q8_0 path: wrapXbBatch (FP32) holds normalized activations; wrapXbFP16Batch is not used. + * Mirrors {@link Qwen3FP16LayersBatchPrefill} but uses Q8_0 weights (ByteArray) and FP32 + * attention normalization path.

+ */ +public class Qwen3Q8_0LayersBatchPrefill implements BatchPrefillTransformerLayerTaskGraphs { + + static final int LOCAL_WORK_GROUP_SIZE = 32; + + private final Qwen3State state; + private final Qwen3TornadoWeights weights; + private final Qwen3Configuration config; + private final KernelContext context = new KernelContext(); + private final int batchSize; + private final int nHeadKv; + private final int nEmbdHeadK; + private final int nEmbdHeadV; + private final int nEmbdHead; + private final int qDim; + private final int kvDim; + private final int gqa; + private final List layerITGs; + private String lastLayerTaskGraphID; + + public Qwen3Q8_0LayersBatchPrefill(Qwen3State state, Qwen3TornadoWeights weights, + Qwen3Configuration config, int batchSize) { + this.state = state; + this.weights = weights; + this.config = config; + this.batchSize = batchSize; + this.nHeadKv = config.numberOfKeyValueHeads(); + this.nEmbdHeadK = config.numberOfHeadsKey(); + this.nEmbdHeadV = config.numberOfHeadsValue(); + this.nEmbdHead = nEmbdHeadV; + this.qDim = nEmbdHeadK * config.numberOfHeads(); + this.kvDim = nEmbdHeadV * nHeadKv; + this.gqa = config.numberOfHeads() / nHeadKv; + this.layerITGs = IntStream.range(0, config.numberOfLayers()) + .mapToObj(this::createBatchPrefillLayerTaskGraph) + .map(TaskGraph::snapshot) + .toList(); + } + + // @formatter:off + private TaskGraph createBatchPrefillLayerTaskGraph(int layerIndex) { + String graphName = "batchPrefillLayer_" + layerIndex; + if (layerIndex == config.numberOfLayers() - 1) lastLayerTaskGraphID = graphName; + + TaskGraph layer = new TaskGraph(graphName); + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + // ── Data Transfers ───────────────────────────────────────────────────── + if (layerIndex == 0) { + layer.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.batchStartPosHolder); + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.attnScaleBatch, state.ffnScaleBatch, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache); + layer.consumeFromDevice("prefillActivation", state.wrapXBatch); + } else { + String pred = "batchPrefillLayer_" + (layerIndex - 1); + layer.consumeFromDevice(pred, + context, + state.wrapXBatch, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapHbBatch, + state.wrapKeyCache, state.wrapValueCache, + state.batchStartPosHolder, + state.attnScaleBatch, state.ffnScaleBatch); + } + + // Per-layer weights (Q8_0 format) + layer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + + // ── Attention Block ──────────────────────────────────────────────────── + layer.task("batch_attn_rms", + TransformerBatchPrefillKernels::batchedRmsReduce, + context, state.wrapXBatch, state.attnScaleBatch, + dim, config.rmsNormEps()); + + // FP32 normalize into wrapXbBatch (Q8_0 path: no FP16 quantize step) + layer.task("batch_attn_rms_apply", + TransformerBatchPrefillKernels::batchedRmsApplyFP32, + context, state.wrapXbBatch, state.wrapXBatch, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + state.attnScaleBatch, dim); + + layer.task("batch_qkv", + Qwen3Kernels::batchedFusedQKVMatmulQ8_0, + context, + state.wrapXbBatch, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + dim, qDim, kvDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_qk_rmsnorm", + Qwen3Kernels::batchedFusedQKRmsNorm, + context, + state.wrapQBatch, state.wrapKBatch, + weights.rms_att_QNormLayered[layerIndex].asFloatArray(), + weights.rms_att_KNormLayered[layerIndex].asFloatArray(), + config.numberOfHeads(), nHeadKv, nEmbdHead, + qDim, kvDim, config.rmsNormEps()); + + layer.task("batch_rope_kv", + Qwen3Kernels::batchedRopeWithKVCacheQwen3, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKBatch, state.wrapVBatch, + state.wrapKeyCache, state.wrapValueCache, + kvDim, nEmbdHead, layerIndex, config.contextLength(), qDim); + + // Reuses batchedFlashAttention; passes qDim as the 'dim' stride (valid: qDim==dim typically). + layer.task("batch_attention", + TransformerBatchPrefillKernels::batchedFlashAttention, + context, state.batchStartPosHolder, + state.wrapQBatch, state.wrapKeyCache, state.wrapValueCache, + state.wrapXbBatch, + config.numberOfHeads(), nEmbdHead, + kvDim, gqa, layerIndex, config.contextLength(), qDim); + + // Output projection (Q8_0): n=qDim, d=dim + layer.task("batch_attn_out", + TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8, + context, state.wrapXbBatch, state.wrapXBatch, + weights.woLayered[layerIndex].asByteArray(), + qDim, dim, LOCAL_WORK_GROUP_SIZE); + + // ── FFN Block ────────────────────────────────────────────────────────── + layer.task("batch_ffn_rms", + TransformerBatchPrefillKernels::batchedFFNRmsReduce, + context, state.wrapXBatch, state.ffnScaleBatch, + dim, config.rmsNormEps()); + + layer.task("batch_ffn_gate_up", + TransformerBatchPrefillKernels::batchedFusedRmsNormFFNGateUpQ8, + context, state.wrapXBatch, state.wrapHbBatch, + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + state.ffnScaleBatch, + weights.w1Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray(), + dim, hidDim, LOCAL_WORK_GROUP_SIZE); + + layer.task("batch_ffn_down", + TransformerBatchPrefillKernels::batchedMatVecWithResidualQ8, + context, state.wrapHbBatch, state.wrapXBatch, + weights.w2Layered[layerIndex].asByteArray(), + hidDim, dim, LOCAL_WORK_GROUP_SIZE); + + layer.persistOnDevice(state.wrapXBatch, state.wrapKeyCache, state.wrapValueCache); + + return layer; + } + // @formatter:on + + public void updateGridScheduler(GridScheduler scheduler) { + int dim = config.dim(); + int hidDim = config.hiddenDim(); + + WorkerGrid rmsWorker = WorkerGridFactory.genericWorker(batchSize, 1); + WorkerGrid rmsApplyWorker = WorkerGridFactory.genericWorker(batchSize * dim, 256); + + int qkvRows = qDim + 2 * kvDim; + WorkerGrid qkvWorker = WorkerGridFactory.genericWorker( + batchSize * qkvRows * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker( + batchSize * (config.numberOfHeads() + nHeadKv) * nEmbdHead, nEmbdHead); + + int ropeGlobal = batchSize * (qDim / 2); + int ropeLocal = Math.min(512, ropeGlobal); + while (ropeLocal > 1 && ropeGlobal % ropeLocal != 0) ropeLocal--; + WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(ropeGlobal, ropeLocal); + + int optLocal = findOptimalLocalSize(nEmbdHead); + WorkerGrid attnWorker = WorkerGridFactory.genericWorker( + batchSize * config.numberOfHeads() * optLocal, optLocal); + + WorkerGrid matVecDimWorker = WorkerGridFactory.genericWorker( + batchSize * dim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + WorkerGrid matVecHidWorker = WorkerGridFactory.genericWorker( + batchSize * hidDim * LOCAL_WORK_GROUP_SIZE, LOCAL_WORK_GROUP_SIZE); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String p = "batchPrefillLayer_" + i + "."; + scheduler.addWorkerGrid(p + "batch_attn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_attn_rms_apply", rmsApplyWorker); + scheduler.addWorkerGrid(p + "batch_qkv", qkvWorker); + scheduler.addWorkerGrid(p + "batch_qk_rmsnorm", qkRmsNormWorker); + scheduler.addWorkerGrid(p + "batch_rope_kv", ropeWorker); + scheduler.addWorkerGrid(p + "batch_attention", attnWorker); + scheduler.addWorkerGrid(p + "batch_attn_out", matVecDimWorker); + scheduler.addWorkerGrid(p + "batch_ffn_rms", rmsWorker); + scheduler.addWorkerGrid(p + "batch_ffn_gate_up", matVecHidWorker); + scheduler.addWorkerGrid(p + "batch_ffn_down", matVecDimWorker); + } + } + + private static int findOptimalLocalSize(int size) { + int optimal = Math.min(size, 64); + if (size % optimal != 0) { + for (int s = 64; s >= 1; s--) { + if (size % s == 0) { optimal = s; break; } + } + } + return optimal; + } + + public List getLayerImmutableTaskGraphs() { return layerITGs; } + public String getLastLayerTaskGraphID() { return lastLayerTaskGraphID; } + public KernelContext getContext() { return context; } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java index 504eb98f..042dd354 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/ForwardPlanFactory.java @@ -176,15 +176,21 @@ private static ForwardPlan createQwen2Q8_0Plan(ExecutionMode mode, Qwen2State st } private static ForwardPlan createQwen3FP16Plan(ExecutionMode mode, Qwen3State state, Model model) { - if (mode != ExecutionMode.STANDARD) - throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + F16"); - return new SingleTokenForwardPlan(model, new Qwen3FP16PlanComponents(state, model)); + BatchPrefillDecodeForwardPlanComponents components = new Qwen3FP16PlanComponents(state, model); + return switch (mode) { + case STANDARD -> new SingleTokenForwardPlan(model, components); + case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components); + case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE); + }; } private static ForwardPlan createQwen3Q8_0Plan(ExecutionMode mode, Qwen3State state, Model model) { - if (mode != ExecutionMode.STANDARD) - throw new UnsupportedOperationException(mode + " not yet supported for QWEN_3 + Q8_0"); - return new SingleTokenForwardPlan(model, new Qwen3Q8_0PlanComponents(state, model)); + BatchPrefillDecodeForwardPlanComponents components = new Qwen3Q8_0PlanComponents(state, model); + return switch (mode) { + case STANDARD -> new SingleTokenForwardPlan(model, components); + case PREFILL_DECODE -> new PrefillDecodeForwardPlan(model, components); + case BATCH_PREFILL_DECODE -> new BatchPrefillDecodeForwardPlan(model, components, TornadoVMMasterPlan.PREFILL_BATCH_SIZE); + }; } private static ForwardPlan createPhi3FP16Plan(ExecutionMode mode, Phi3State state, Model model) { diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java index 5b63530d..ea9fb1fe 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchDecodeActivation.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan.components.activation; -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; import org.beehive.gpullama3.tornadovm.scheduling.WorkerGridFactory; @@ -26,14 +26,14 @@ public class BatchDecodeActivation implements ActivationTaskGraph { private final ImmutableTaskGraph itg; private final int dim; - public BatchDecodeActivation(LlamaState state, LlamaConfiguration config, String lastBatchLayerId, boolean isQ8) { + public BatchDecodeActivation(State state, Configuration config, String lastBatchLayerId, boolean isQ8) { this.dim = config.dim(); KernelContext ctx = new KernelContext(); this.itg = buildGraph(ctx, state, lastBatchLayerId, isQ8).snapshot(); } // @formatter:off - private TaskGraph buildGraph(KernelContext ctx, LlamaState state, + private TaskGraph buildGraph(KernelContext ctx, State state, String lastBatchLayerId, boolean isQ8) { TaskGraph tg = new TaskGraph("decodeActivation") .consumeFromDevice(lastBatchLayerId, state.wrapKeyCache, state.wrapValueCache) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java index 17b80cfa..aa1b02d7 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/activation/BatchPrefillActivation.java @@ -1,7 +1,7 @@ package org.beehive.gpullama3.tornadovm.plan.components.activation; -import org.beehive.gpullama3.inference.state.LlamaState; -import org.beehive.gpullama3.model.llama.LlamaConfiguration; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.tornadovm.kernels.TransformerBatchPrefillKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; @@ -29,7 +29,7 @@ public class BatchPrefillActivation implements ActivationTaskGraph { private final int batchSize; private final int dim; - public BatchPrefillActivation(LlamaState state, LlamaConfiguration config, int batchSize, boolean isQ8) { + public BatchPrefillActivation(State state, Configuration config, int batchSize, boolean isQ8) { this.isQ8 = isQ8; this.batchSize = batchSize; this.dim = config.dim(); @@ -37,7 +37,7 @@ public BatchPrefillActivation(LlamaState state, LlamaConfiguration config, int b this.itg = buildGraph(ctx, state).snapshot(); } - private TaskGraph buildGraph(KernelContext ctx, LlamaState state) { + private TaskGraph buildGraph(KernelContext ctx, State state) { if (isQ8) { return new TaskGraph("prefillActivation") .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapXBatch) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java index 5b2fe286..4cb7a0aa 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/fp16/Qwen3FP16PlanComponents.java @@ -7,14 +7,21 @@ import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.Qwen3FP16FFNLayers; -import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.LogitsFP16LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.Qwen3FP16FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.decode.Qwen3FP16FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.prefill.Qwen3FP16LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -public class Qwen3FP16PlanComponents implements SingleTokenForwardPlanComponents { +public class Qwen3FP16PlanComponents implements BatchPrefillDecodeForwardPlanComponents { private final Qwen3State state; private final Qwen3TornadoWeights weights; @@ -28,18 +35,59 @@ public Qwen3FP16PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } + // ── Activations ─────────────────────────────────────────────────────────── + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } + @Override + public ActivationTaskGraph prefillDecodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override + public ActivationTaskGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, false); + } + + @Override + public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, false); + } + + // ── Transformer layer TaskGraphs ────────────────────────────────────────── + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen3FP16FFNLayers("qwen3FFN", state, weights, config, schedulerType); } + @Override + public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { + return new Qwen3FP16FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override + public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { + return new Qwen3FP16FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override + public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { + return new Qwen3FP16LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsFP16Layer("logits", state, weights, config, previousGraphId, schedulerType); } + + @Override + public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { + return new LogitsFP16LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java index 024c0364..6f067d1f 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/plan/components/q8_0/Qwen3Q8_0PlanComponents.java @@ -7,14 +7,21 @@ import org.beehive.gpullama3.tornadovm.layers.AbstractLogitsTaskGraph; import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationTaskGraph; +import org.beehive.gpullama3.tornadovm.layers.BatchPrefillTransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.TransformerLayerTaskGraphs; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsQ8_0Layer; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.Qwen3Q8_0FFNLayers; -import org.beehive.gpullama3.tornadovm.plan.components.SingleTokenForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.LogitsQ8_0LayerDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.Qwen3Q8_0FFNLayersDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.decode.Qwen3Q8_0FFNLayersPrefillDecode; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.prefill.Qwen3Q8_0LayersBatchPrefill; +import org.beehive.gpullama3.tornadovm.plan.components.BatchPrefillDecodeForwardPlanComponents; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchDecodeActivation; +import org.beehive.gpullama3.tornadovm.plan.components.activation.BatchPrefillActivation; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerDetectionService; import org.beehive.gpullama3.tornadovm.scheduling.SchedulerType; -public class Qwen3Q8_0PlanComponents implements SingleTokenForwardPlanComponents { +public class Qwen3Q8_0PlanComponents implements BatchPrefillDecodeForwardPlanComponents { private final Qwen3State state; private final Qwen3TornadoWeights weights; @@ -28,18 +35,59 @@ public Qwen3Q8_0PlanComponents(Qwen3State state, Model model) { this.schedulerType = SchedulerDetectionService.determineSchedulerType(model); } + // ── Activations ─────────────────────────────────────────────────────────── + @Override public ActivationTaskGraph singleTokenActivation() { return new Activation("activationUpdate", state, weights, config); } + @Override + public ActivationTaskGraph prefillDecodeActivation() { + return new Activation("decodeActivation", state, weights, config); + } + + @Override + public ActivationTaskGraph batchPrefillActivation(int batchSize) { + return new BatchPrefillActivation(state, config, batchSize, true); + } + + @Override + public ActivationTaskGraph batchDecodeActivation(String lastBatchLayerId) { + return new BatchDecodeActivation(state, config, lastBatchLayerId, true); + } + + // ── Transformer layer TaskGraphs ────────────────────────────────────────── + @Override public TransformerLayerTaskGraphs singleTokenTransformerLayers() { return new Qwen3Q8_0FFNLayers("qwen3FFN", state, weights, config, schedulerType); } + @Override + public TransformerLayerTaskGraphs prefillDecodeTransformerLayers() { + return new Qwen3Q8_0FFNLayersPrefillDecode("decode", state, weights, config, schedulerType); + } + + @Override + public TransformerLayerTaskGraphs batchDecodeTransformerLayers() { + return new Qwen3Q8_0FFNLayersDecode("decode", state, weights, config, schedulerType); + } + + @Override + public BatchPrefillTransformerLayerTaskGraphs batchPrefillTransformerLayers(int batchSize) { + return new Qwen3Q8_0LayersBatchPrefill(state, weights, config, batchSize); + } + + // ── Logits layers ───────────────────────────────────────────────────────── + @Override public AbstractLogitsTaskGraph singleTokenLogits(String previousGraphId) { return new LogitsQ8_0Layer("logits", state, weights, config, previousGraphId, schedulerType); } + + @Override + public AbstractLogitsTaskGraph decodeLogits(String previousGraphId) { + return new LogitsQ8_0LayerDecode("logits", state, weights, config, previousGraphId, schedulerType); + } } diff --git a/src/test/java/org/beehive/gpullama3/model/format/ToolCallParserUtilsTest.java b/src/test/java/org/beehive/gpullama3/model/format/ToolCallParserUtilsTest.java new file mode 100644 index 00000000..74e74ea6 --- /dev/null +++ b/src/test/java/org/beehive/gpullama3/model/format/ToolCallParserUtilsTest.java @@ -0,0 +1,166 @@ +package org.beehive.gpullama3.model.format; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.List; +import java.util.Optional; + +import org.junit.Test; + +/** + * Unit tests for {@link ToolCallParserUtils}. + * + *

The parser is pure string-handling (no model/tokenizer), so every recognised response + * shape is covered here: Qwen3/Llama {@code } tags, Llama 3.1 {@code <|python_tag|>}, + * raw-JSON and markdown-fence fallbacks, unclosed tags, and batch (multi-call) responses. + * The brace-in-string cases pin the fix that keeps argument objects whose string values + * contain {@code {}/{}} characters (e.g. source code) intact. + */ +public class ToolCallParserUtilsTest { + + // ── Single-call extraction ──────────────────────────────────────────────── + + @Test + public void qwen3ToolCall_arguments() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "\n{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Chania\"}}\n"); + assertTrue(tc.isPresent()); + assertEquals("get_weather", tc.get().name()); + assertEquals("{\"city\": \"Chania\"}", tc.get().argumentsJson()); + } + + @Test + public void llama31_pythonTag_parametersKey() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "<|python_tag|>{\"name\": \"get_weather\", \"parameters\": {\"city\": \"Boston\"}}"); + assertTrue(tc.isPresent()); + assertEquals("get_weather", tc.get().name()); + assertEquals("{\"city\": \"Boston\"}", tc.get().argumentsJson()); + } + + @Test + public void functionKey_usedAsNameFallback() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"function\": \"list_dir\", \"arguments\": {\"path\": \"/tmp\"}}"); + assertTrue(tc.isPresent()); + assertEquals("list_dir", tc.get().name()); + } + + @Test + public void missingArguments_defaultsToEmptyObject() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"now\"}"); + assertTrue(tc.isPresent()); + assertEquals("now", tc.get().name()); + assertEquals("{}", tc.get().argumentsJson()); + } + + @Test + public void unclosedToolCall_stillParsed() { + // Model stopped (eot/eom) before emitting the closing tag. + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"ping\", \"arguments\": {\"host\": \"a\"}}"); + assertTrue(tc.isPresent()); + assertEquals("ping", tc.get().name()); + assertEquals("{\"host\": \"a\"}", tc.get().argumentsJson()); + } + + @Test + public void plainTextResponse_isNotAToolCall() { + assertFalse(ToolCallParserUtils.parseToolCallResponse("The weather in Chania is sunny.").isPresent()); + } + + // ── Brace-in-string argument objects (the core fix) ─────────────────────── + + @Test + public void argumentsWithBracesInStringValue_keptIntact() { + String args = "{\"code\": \"public class A { void m() { return; } }\"}"; + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"write_file\", \"arguments\": " + args + "}"); + assertTrue(tc.isPresent()); + assertEquals("write_file", tc.get().name()); + assertEquals(args, tc.get().argumentsJson()); + } + + @Test + public void argumentsWithEscapedQuotesAndBraces_keptIntact() { + String args = "{\"snippet\": \"if (s.equals(\\\"}\\\")) { x++; }\"}"; + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"run\", \"arguments\": " + args + "}"); + assertTrue(tc.isPresent()); + assertEquals(args, tc.get().argumentsJson()); + } + + @Test + public void argumentsWithNestedObjectsAndArrays_keptIntact() { + String args = "{\"items\": [{\"a\": 1}, {\"b\": 2}], \"meta\": {\"n\": 3}}"; + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"batch\", \"arguments\": " + args + "}"); + assertTrue(tc.isPresent()); + assertEquals(args, tc.get().argumentsJson()); + } + + // ── Fallbacks ───────────────────────────────────────────────────────────── + + @Test + public void rawJsonFallback_noTags() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "{\"name\": \"echo\", \"arguments\": {\"msg\": \"hi\"}}"); + assertTrue(tc.isPresent()); + assertEquals("echo", tc.get().name()); + } + + @Test + public void markdownFencedJson_fallback() { + Optional tc = ToolCallParserUtils.parseToolCallResponse( + "```json\n{\"name\": \"echo\", \"arguments\": {\"msg\": \"hi\"}}\n```"); + assertTrue(tc.isPresent()); + assertEquals("echo", tc.get().name()); + assertEquals("{\"msg\": \"hi\"}", tc.get().argumentsJson()); + } + + @Test + public void stripMarkdownFences_removesFenceLines() { + assertEquals("body", ToolCallParserUtils.stripMarkdownFences("```\nbody\n```")); + assertEquals("plain", ToolCallParserUtils.stripMarkdownFences("plain")); + } + + // ── Batch (multiple tool calls) ─────────────────────────────────────────── + + @Test + public void batch_multipleToolCallBlocks() { + List calls = ToolCallParserUtils.parseAllToolCalls( + "{\"name\": \"a\", \"arguments\": {\"x\": 1}}" + + "{\"name\": \"b\", \"arguments\": {\"y\": 2}}"); + assertEquals(2, calls.size()); + assertEquals("a", calls.get(0).name()); + assertEquals("{\"x\": 1}", calls.get(0).argumentsJson()); + assertEquals("b", calls.get(1).name()); + assertEquals("{\"y\": 2}", calls.get(1).argumentsJson()); + } + + @Test + public void batch_bracesInStringDoNotBleedAcrossCalls() { + List calls = ToolCallParserUtils.parseAllToolCalls( + "{\"name\": \"write\", \"arguments\": {\"code\": \"a { b }\"}}" + + "{\"name\": \"log\", \"arguments\": {\"msg\": \"ok\"}}"); + assertEquals(2, calls.size()); + assertEquals("{\"code\": \"a { b }\"}", calls.get(0).argumentsJson()); + assertEquals("log", calls.get(1).name()); + } + + @Test + public void batch_pythonTagIsSingleCall() { + List calls = ToolCallParserUtils.parseAllToolCalls( + "<|python_tag|>{\"name\": \"a\", \"parameters\": {\"x\": 1}}"); + assertEquals(1, calls.size()); + assertEquals("a", calls.get(0).name()); + } + + @Test + public void batch_noToolCalls_returnsEmpty() { + assertTrue(ToolCallParserUtils.parseAllToolCalls("just a plain answer").isEmpty()); + } +}