diff --git a/kotlin-js-store/wasm/yarn.lock b/kotlin-js-store/wasm/yarn.lock index fd812f32..b990b1c9 100644 --- a/kotlin-js-store/wasm/yarn.lock +++ b/kotlin-js-store/wasm/yarn.lock @@ -2,7 +2,7 @@ # yarn lockfile v1 -ws@8.18.3: - version "8.18.3" - resolved "https://registry.yarnpkg.com/ws/-/ws-8.18.3.tgz#b56b88abffde62791c639170400c93dcb0c95472" - integrity sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg== +ws@8.20.1: + version "8.20.1" + resolved "https://registry.yarnpkg.com/ws/-/ws-8.20.1.tgz#91a9ae2b312ccf98e0a85ec499b48cef45ab0ddb" + integrity sha512-It4dO0K5v//JtTXuPkfEOaI3uUN87iYPnqo/ZzqCoG3g8uhA66QUMs/SrM0YK7/NAu+r4LMh/9dq2A7k+rHs+w== diff --git a/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt b/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt index 492ac4b7..cc42e955 100644 --- a/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt +++ b/skainet-backends/skainet-backend-native-cpu/native/CMakeLists.txt @@ -13,6 +13,7 @@ set(SKAINET_KERNEL_SOURCES src/skainet_smoke.c src/q4k_matmul.c src/q5k_matmul.c + src/q6k_matmul.c src/fp32_matmul.c src/bf16_matmul.c src/q8_0_matmul.c diff --git a/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h b/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h index 167fc80b..6d58e7eb 100644 --- a/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h +++ b/skainet-backends/skainet-backend-native-cpu/native/include/skainet_kernels.h @@ -85,6 +85,33 @@ SKAINET_API void skainet_q5k_matmul( int32_t output_offset ); +/* + * Q6_K matrix-vector multiply. + * + * output[output_offset + o] = sum_j input[input_offset + j] * + * dequant(weight[block, o, j]) + * + * Block layout: canonical ggml Q6_K, 256 elements per super-block, 210 + * bytes per block (128 B `ql` low nibbles + 64 B `qh` high-2-bit plane + + * 16 B int8 `scales` + 2 B `d` FP16). Each 6-bit code is + * `lowNibble | (highBits << 4)`, dequantized as `d * scale * (code - 32)` + * (signed, range [-32, 31]; Q6_K has no per-block min). Packed weights + * laid out as + * weight + weight_byte_offset + (block_idx * output_dim + o) * 210 + * + * input_dim must be a multiple of 256. + */ +SKAINET_API void skainet_q6k_matmul( + const float* input, + int32_t input_offset, + const uint8_t* weight, + int32_t weight_byte_offset, + int32_t input_dim, + int32_t output_dim, + float* output, + int32_t output_offset +); + /* * Row-major FP32 SGEMM: C(m, n) = A(m, k) * B(k, n). * diff --git a/skainet-backends/skainet-backend-native-cpu/native/src/q6k_matmul.c b/skainet-backends/skainet-backend-native-cpu/native/src/q6k_matmul.c new file mode 100644 index 00000000..d4e5ab11 --- /dev/null +++ b/skainet-backends/skainet-backend-native-cpu/native/src/q6k_matmul.c @@ -0,0 +1,146 @@ +#include "skainet_kernels.h" +#include "skainet_simd.h" + +#include +#include + +#define Q6K_BLOCK_SIZE 256 +#define Q6K_BYTES_PER_BLOCK 210 +#define Q6K_QL_OFFSET 0 +#define Q6K_QH_OFFSET 128 +#define Q6K_SCALES_OFFSET 192 +#define Q6K_D_OFFSET 208 + +/* + * IEEE 754 binary16 (LE byte order) -> binary32 conversion. + * Byte-for-byte identical to the Q5_K / Q4_K converter (kept scalar to + * preserve bit-exact FP16 parity with the Panama / scalar references). + */ +static inline float skainet_q6k_half_to_float(uint16_t hbits) { + const uint32_t sign = (hbits >> 15) & 0x1u; + const uint32_t exp = (hbits >> 10) & 0x1Fu; + const uint32_t frac = hbits & 0x3FFu; + + if (exp == 0u) { + if (frac == 0u) { + union { uint32_t u; float f; } v = { sign << 31 }; + return v.f; + } + float f = ((float) frac) / 1024.0f * (1.0f / 16384.0f); + return sign ? -f : f; + } + if (exp == 0x1Fu) { + union { uint32_t u; float f; } v; + v.u = (sign << 31) | 0x7F800000u | (frac ? 0x00400000u : 0u); + return v.f; + } + union { uint32_t u; float f; } v; + v.u = (sign << 31) | ((exp - 15u + 127u) << 23) | (frac << 13); + return v.f; +} + +/* + * Dequantize one 256-element Q6_K super-block into scratch[256]. + * Direct transcription of ScalarQ6_KMatmulKernel.dequantBlock / + * ggml dequantize_row_q6_K: two 128-element halves, each split into two + * 16-element scale groups carrying four strided sub-codes (q1..q4). + * + * The 6-bit code is `lowNibble(ql) | (twoHighBits(qh) << 4)`, biased by + * -32, and `scales` are SIGNED int8. Per-element value = d * scale * code. + */ +static inline void skainet_q6k_dequant_block(const uint8_t* SKAINET_RESTRICT block, + float* SKAINET_RESTRICT scratch) { + const uint8_t* ql0 = block + Q6K_QL_OFFSET; + const uint8_t* qh0 = block + Q6K_QH_OFFSET; + const int8_t* sc0 = (const int8_t*)(block + Q6K_SCALES_OFFSET); + const uint16_t d_bits = (uint16_t) block[Q6K_D_OFFSET] + | ((uint16_t) block[Q6K_D_OFFSET + 1] << 8); + const float d = skainet_q6k_half_to_float(d_bits); + + for (int half = 0; half < 2; ++half) { + const uint8_t* ql = ql0 + half * 64; + const uint8_t* qh = qh0 + half * 32; + const int8_t* sc = sc0 + half * 8; + float* out = scratch + half * 128; + for (int is = 0; is < 2; ++is) { + const float sc1 = d * (float) sc[is + 0]; + const float sc2 = d * (float) sc[is + 2]; + const float sc3 = d * (float) sc[is + 4]; + const float sc4 = d * (float) sc[is + 6]; + const int l_start = is * 16; + for (int l = l_start; l < l_start + 16; ++l) { + const int q_l0 = ql[l]; + const int q_l32 = ql[l + 32]; + const int q_h = qh[l]; + const int q1 = ((q_l0 & 0x0F) | ((q_h & 0x03) << 4)) - 32; + const int q2 = ((q_l32 & 0x0F) | (((q_h >> 2) & 0x03) << 4)) - 32; + const int q3 = ((q_l0 >> 4) | (((q_h >> 4) & 0x03) << 4)) - 32; + const int q4 = ((q_l32 >> 4) | (((q_h >> 6) & 0x03) << 4)) - 32; + out[l + 0] = sc1 * (float) q1; + out[l + 32] = sc2 * (float) q2; + out[l + 64] = sc3 * (float) q3; + out[l + 96] = sc4 * (float) q4; + } + } + } +} + +/* + * Native Q6_K matrix-vector multiply matching the + * sk.ainet.backend.api.kernel.Q6KMatmulKernel SPI contract. A single + * input row times an `outputDim x inputDim` Q6_K-packed weight tensor + * laid out (blockIdx * outputDim + o) * 210 bytes. + * + * The 6-bit bit-assembly is kept scalar (cheap byte shuffling that the + * compiler auto-vectorizes under -O3) and materialized into a 256-float + * scratch block; the hot dot product against the input window is the + * NEON path (vfmaq_f32 + horizontal add) behind __ARM_NEON. On non-ARM + * targets the dot is a straight-line loop that auto-vectorizes too. + */ +SKAINET_API void skainet_q6k_matmul( + const float* SKAINET_RESTRICT input, + int32_t input_offset, + const uint8_t* SKAINET_RESTRICT weight, + int32_t weight_byte_offset, + int32_t input_dim, + int32_t output_dim, + float* SKAINET_RESTRICT output, + int32_t output_offset +) { + if (output_dim <= 0 || input_dim <= 0) return; + + const int32_t blocks_per_input_dim = input_dim / Q6K_BLOCK_SIZE; + const float* in_base = input + input_offset; + float* out_base = output + output_offset; + + float scratch[Q6K_BLOCK_SIZE]; + + for (int32_t o = 0; o < output_dim; ++o) { + float acc = 0.0f; + + for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) { + const uint8_t* block = weight + weight_byte_offset + + (size_t)(block_idx * output_dim + o) * Q6K_BYTES_PER_BLOCK; + + skainet_q6k_dequant_block(block, scratch); + + const float* in_block = in_base + (size_t) block_idx * Q6K_BLOCK_SIZE; + +#ifdef SKAINET_HAVE_NEON + float32x4_t vacc = vdupq_n_f32(0.0f); + for (int i = 0; i < Q6K_BLOCK_SIZE; i += 4) { + const float32x4_t vi = vld1q_f32(in_block + i); + const float32x4_t vw = vld1q_f32(scratch + i); + vacc = vfmaq_f32(vacc, vi, vw); + } + acc += skainet_neon_hadd_f32(vacc); +#else + for (int i = 0; i < Q6K_BLOCK_SIZE; ++i) { + acc += in_block[i] * scratch[i]; + } +#endif + } + + out_base[o] = acc; + } +} diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt index ba0011b2..f52221f8 100644 --- a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeKernelProvider.kt @@ -8,6 +8,7 @@ import sk.ainet.backend.api.kernel.Q4KMatmulKernel import sk.ainet.backend.api.kernel.Q4KMemSegMatmulKernel import sk.ainet.backend.api.kernel.Q4_0MatmulKernel import sk.ainet.backend.api.kernel.Q5KMatmulKernel +import sk.ainet.backend.api.kernel.Q6KMatmulKernel import sk.ainet.backend.api.kernel.Q8_0MatmulKernel /** @@ -73,7 +74,7 @@ import sk.ainet.backend.api.kernel.Q8_0MatmulKernel * - PR 2: real Q4_K matmul wired into the heap SPI. * - PR 3: MemSeg-input zero-copy sibling. * - PR 5: native FP32 matmul wired into [matmulFp32]. - * - Later: native `matmulQ6K`, `matmulQ8_0` (need new SPI accessors). + * - Now: native `matmulQ5K`, `matmulQ6K`, `matmulQ8_0`, `matmulQ4_0` all wired. */ public object NativeKernelProvider : KernelProvider, MemSegKernelProvider { override val name: String = "native-ffm" @@ -101,4 +102,7 @@ public object NativeKernelProvider : KernelProvider, MemSegKernelProvider { override fun matmulQ5K(): Q5KMatmulKernel? = if (NativeQ5KMatmulKernel.isAvailable()) NativeQ5KMatmulKernel else null + + override fun matmulQ6K(): Q6KMatmulKernel? = + if (NativeQ6KMatmulKernel.isAvailable()) NativeQ6KMatmulKernel else null } diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ6KMatmulKernel.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ6KMatmulKernel.kt new file mode 100644 index 00000000..382f57a7 --- /dev/null +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmMain/kotlin/sk/ainet/exec/kernel/NativeQ6KMatmulKernel.kt @@ -0,0 +1,91 @@ +package sk.ainet.exec.kernel + +import java.lang.foreign.Arena +import java.lang.foreign.FunctionDescriptor +import java.lang.foreign.Linker +import java.lang.foreign.MemorySegment +import java.lang.foreign.ValueLayout +import java.lang.invoke.MethodHandle +import sk.ainet.backend.api.kernel.Q6KMatmulKernel + +/** + * Native (FFM) implementation of [Q6KMatmulKernel]. + * + * Wraps the bundled C symbol + * + * void skainet_q6k_matmul( + * const float* input, int32_t input_offset, + * const uint8_t* weight, int32_t weight_byte_offset, + * int32_t input_dim, int32_t output_dim, + * float* output, int32_t output_offset); + * + * Canonical 256-element / 210-byte Q6_K super-block: `ql` low nibbles + + * `qh` high-2-bit plane + 16 int8 `scales` + FP16 `d`; each 6-bit code is + * dequantized as `d * scale * (code - 32)`. Numerical parity vs + * [PanamaVectorQ6_KMatmulKernel] is asserted by [NativeQ6KMatmulKernelParityTest]. + * + * Scalar 6-bit bit-assembly (`-O3 -ffast-math`, auto-vectorized) feeding a + * NEON dot product behind `__ARM_NEON`. + */ +internal object NativeQ6KMatmulKernel : Q6KMatmulKernel { + + private const val BLOCK_SIZE = 256 + private const val BYTES_PER_BLOCK = 210 + + fun isAvailable(): Boolean = handle != null + + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "NativeQ6KMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0 || inputDim == 0) return + val mh = handle + ?: error("NativeQ6KMatmulKernel.matmul invoked while native library unavailable") + + Arena.ofConfined().use { arena -> + val inSeg = arena.allocate( + inputDim.toLong() * java.lang.Float.BYTES, + ValueLayout.JAVA_FLOAT.byteAlignment(), + ) + val outSeg = arena.allocate( + outputDim.toLong() * java.lang.Float.BYTES, + ValueLayout.JAVA_FLOAT.byteAlignment(), + ) + val weightBytesUsed = ((inputDim / BLOCK_SIZE).toLong() * outputDim) * BYTES_PER_BLOCK.toLong() + val weightSeg = arena.allocate(weightBytesUsed, 1L) + + MemorySegment.copy(input, inputOffset, inSeg, ValueLayout.JAVA_FLOAT, 0L, inputDim) + MemorySegment.copy(weight, weightByteOffset, weightSeg, ValueLayout.JAVA_BYTE, 0L, weightBytesUsed.toInt()) + + mh.invoke( + inSeg, 0, + weightSeg, 0, + inputDim, outputDim, + outSeg, 0, + ) + + MemorySegment.copy(outSeg, ValueLayout.JAVA_FLOAT, 0L, output, outputOffset, outputDim) + } + } + + private val handle: MethodHandle? by lazy { + val lookup = NativeLibraryLoader.lookup() ?: return@lazy null + val symbol = lookup.find("skainet_q6k_matmul").orElse(null) ?: return@lazy null + val descriptor = FunctionDescriptor.ofVoid( + ValueLayout.ADDRESS, // input + ValueLayout.JAVA_INT, // input_offset + ValueLayout.ADDRESS, // weight + ValueLayout.JAVA_INT, // weight_byte_offset + ValueLayout.JAVA_INT, // input_dim + ValueLayout.JAVA_INT, // output_dim + ValueLayout.ADDRESS, // output + ValueLayout.JAVA_INT, // output_offset + ) + runCatching { Linker.nativeLinker().downcallHandle(symbol, descriptor) }.getOrNull() + } +} diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ6KMatmulKernelParityTest.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ6KMatmulKernelParityTest.kt new file mode 100644 index 00000000..2e5ad44a --- /dev/null +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ6KMatmulKernelParityTest.kt @@ -0,0 +1,95 @@ +package sk.ainet.exec.kernel + +import kotlin.math.abs +import kotlin.random.Random +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertTrue + +/** + * Numerical parity tests for [NativeQ6KMatmulKernel] against + * [PanamaVectorQ6_KMatmulKernel]. Both kernels share the canonical Q6_K + * layout (210-byte block: 128 B `ql` + 64 B `qh` + 16 B int8 `scales` + + * 2 B FP16 `d`) and dequant `d * scale * (code - 32)`, so outputs must + * agree element-wise within FMA + reordered-reduction tolerance. + * + * Fixture mirrors [NativeQ5KMatmulKernelParityTest]: random Q6_K bytes with + * `d` clamped to `1.0f16` (bytes 208-209), packed input-block-major + * `(blockIdx * outputDim + o) * 210`. Random `ql`/`qh`/`scales` exercise the + * 6-bit bit-assembly and the signed int8 scales. Q6_K magnitudes are larger + * than Q5_K (codes [-32, 31] × int8 scales), so absolute tolerances are a + * touch looser; the `rel < 1e-4` relative check is the real gate. + */ +class NativeQ6KMatmulKernelParityTest { + + private val blockSize = 256 + private val bytesPerBlock = 210 + + @BeforeTest + fun checkNativeAvailable() { + assertTrue( + NativeQ6KMatmulKernel.isAvailable(), + "NativeQ6KMatmulKernel reports unavailable on this host — bundled libskainet_kernels " + + "missing or skainet_q6k_matmul symbol unresolved", + ) + } + + private fun randomQ6KBytes(numBlocks: Int, seed: Int): ByteArray { + val rng = Random(seed) + val bytes = ByteArray(numBlocks * bytesPerBlock) + rng.nextBytes(bytes) + for (block in 0 until numBlocks) { + val base = block * bytesPerBlock + // 0x3C00 == 1.0f16 at the Q6_K `d` slot (bytes 208-209, LE). + bytes[base + 208] = 0x00.toByte() + bytes[base + 209] = 0x3C.toByte() + } + return bytes + } + + private fun assertParity(inputDim: Int, outputDim: Int, seed: Int, tol: Float) { + val numBlocks = (inputDim / blockSize) * outputDim + val packed = randomQ6KBytes(numBlocks, seed) + val input = FloatArray(inputDim) { Random(seed + it).nextFloat() - 0.5f } + + val refOut = FloatArray(outputDim) + PanamaVectorQ6_KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, refOut, 0) + + val nativeOut = FloatArray(outputDim) + NativeQ6KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, nativeOut, 0) + + for (o in 0 until outputDim) { + val diff = abs(refOut[o] - nativeOut[o]) + val rel = diff / (abs(refOut[o]) + 1e-9f) + assertTrue( + diff <= tol || rel < 1e-4f, + "row $o diverged: panama=${refOut[o]} native=${nativeOut[o]} diff=$diff rel=$rel tol=$tol", + ) + } + } + + @Test + fun single_block_single_row() = assertParity(256, 1, 42, 1e-2f) + + @Test + fun single_block_multi_row() = assertParity(256, 16, 7, 5e-2f) + + @Test + fun multi_block_multi_row() = assertParity(1024, 64, 123, 2e-1f) + + @Test + fun llm_typical_shape_4096_outputDim_64() = assertParity(4096, 64, 999, 2e0f) + + @Test + fun rejects_inputDim_not_multiple_of_block() { + val packed = randomQ6KBytes(numBlocks = 2, seed = 1) + val input = FloatArray(255) + val out = FloatArray(1) + try { + NativeQ6KMatmulKernel.matmul(input, 0, packed, 0, 255, 1, out, 0) + kotlin.test.fail("expected IllegalArgumentException for non-multiple inputDim") + } catch (e: IllegalArgumentException) { + // expected + } + } +} diff --git a/skainet-backends/skainet-backend-native-cpu/src/linuxX64Test/kotlin/sk/ainet/exec/kernel/NativeKnQ6KMatmulKernelParityTest.kt b/skainet-backends/skainet-backend-native-cpu/src/linuxX64Test/kotlin/sk/ainet/exec/kernel/NativeKnQ6KMatmulKernelParityTest.kt new file mode 100644 index 00000000..922f5f14 --- /dev/null +++ b/skainet-backends/skainet-backend-native-cpu/src/linuxX64Test/kotlin/sk/ainet/exec/kernel/NativeKnQ6KMatmulKernelParityTest.kt @@ -0,0 +1,70 @@ +package sk.ainet.exec.kernel + +import kotlin.math.abs +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertTrue + +/** + * Proves the Kotlin/Native cinterop path: [NativeKnQ6KMatmulKernel] (calling the + * C `skainet_q6k_matmul` via cinterop, linked from libskainet_kernels.a) must + * agree with the commonMain [ScalarQ6_KMatmulKernel] reference within FMA + + * `-ffast-math` reassociation tolerance. + * + * This is the host (linuxX64) de-risking of the board (linuxArm64) consumption: + * the cinterop mechanism + kernel correctness are verified here; only the NEON + * codegen differs on aarch64 (board-verify-pending). Q6_K magnitudes (codes + * [-32, 31] × signed int8 scales) are larger than Q5_K, so absolute tolerances + * are a touch looser; the `rel < 1e-4` relative check is the real gate. + */ +class NativeKnQ6KMatmulKernelParityTest { + + private val blockSize = 256 + private val bytesPerBlock = 210 + + private fun randomQ6KBytes(numBlocks: Int, seed: Int): ByteArray { + val rng = Random(seed) + val bytes = ByteArray(numBlocks * bytesPerBlock) + rng.nextBytes(bytes) + for (block in 0 until numBlocks) { + val base = block * bytesPerBlock + // 0x3C00 == 1.0f16 at the Q6_K `d` slot (bytes 208-209, LE). + bytes[base + 208] = 0x00.toByte() + bytes[base + 209] = 0x3C.toByte() + } + return bytes + } + + private fun assertParity(inputDim: Int, outputDim: Int, seed: Int, tol: Float) { + val numBlocks = (inputDim / blockSize) * outputDim + val packed = randomQ6KBytes(numBlocks, seed) + val input = FloatArray(inputDim) { Random(seed + it).nextFloat() - 0.5f } + + val refOut = FloatArray(outputDim) + ScalarQ6_KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, refOut, 0) + + val knOut = FloatArray(outputDim) + NativeKnQ6KMatmulKernel.matmul(input, 0, packed, 0, inputDim, outputDim, knOut, 0) + + for (o in 0 until outputDim) { + val diff = abs(refOut[o] - knOut[o]) + val rel = diff / (abs(refOut[o]) + 1e-9f) + assertTrue( + diff <= tol || rel < 1e-4f, + "row $o diverged: scalar=${refOut[o]} cinterop=${knOut[o]} diff=$diff rel=$rel tol=$tol", + ) + } + } + + @Test + fun single_block_single_row() = assertParity(256, 1, 42, 1e-2f) + + @Test + fun single_block_multi_row() = assertParity(256, 16, 7, 5e-2f) + + @Test + fun multi_block_multi_row() = assertParity(1024, 64, 123, 2e-1f) + + @Test + fun llm_typical_shape() = assertParity(4096, 64, 999, 2e0f) +} diff --git a/skainet-backends/skainet-backend-native-cpu/src/nativeMain/kotlin/sk/ainet/exec/kernel/NativeKnKernelProvider.kt b/skainet-backends/skainet-backend-native-cpu/src/nativeMain/kotlin/sk/ainet/exec/kernel/NativeKnKernelProvider.kt index 83f80c36..5bc47def 100644 --- a/skainet-backends/skainet-backend-native-cpu/src/nativeMain/kotlin/sk/ainet/exec/kernel/NativeKnKernelProvider.kt +++ b/skainet-backends/skainet-backend-native-cpu/src/nativeMain/kotlin/sk/ainet/exec/kernel/NativeKnKernelProvider.kt @@ -10,9 +10,11 @@ import sk.ainet.backend.api.kernel.KernelRegistry import sk.ainet.backend.api.kernel.Q4KMatmulKernel import sk.ainet.backend.api.kernel.Q4_0MatmulKernel import sk.ainet.backend.api.kernel.Q5KMatmulKernel +import sk.ainet.backend.api.kernel.Q6KMatmulKernel import sk.ainet.backend.api.kernel.Q8_0MatmulKernel import sk.ainet.kernels.cinterop.skainet_q4_0_matmul import sk.ainet.kernels.cinterop.skainet_q4k_matmul +import sk.ainet.kernels.cinterop.skainet_q6k_matmul import sk.ainet.kernels.cinterop.skainet_q8_0_matmul /** @@ -23,8 +25,8 @@ import sk.ainet.kernels.cinterop.skainet_q8_0_matmul * * **Registration is manual on K/N** (no `ServiceLoader`): a consumer calls * [installNativeKernels] once at startup. [Q5KMatmulKernel] (the FunctionGemma - * Q5_K_M hot path) plus Q4_K / Q8_0 / Q4_0 are wired; the rest cascade to the - * scalar provider. + * Q5_K_M hot path) plus Q4_K / Q6_K / Q8_0 / Q4_0 are wired; the rest cascade to + * the scalar provider. */ @OptIn(ExperimentalForeignApi::class) public object NativeKnKernelProvider : KernelProvider { @@ -40,6 +42,7 @@ public object NativeKnKernelProvider : KernelProvider { override fun matmulQ5K(): Q5KMatmulKernel = NativeKnQ5KMatmulKernel override fun matmulQ4K(): Q4KMatmulKernel = NativeKnQ4KMatmulKernel + override fun matmulQ6K(): Q6KMatmulKernel = NativeKnQ6KMatmulKernel override fun matmulQ8_0(): Q8_0MatmulKernel = NativeKnQ8_0MatmulKernel override fun matmulQ4_0(): Q4_0MatmulKernel = NativeKnQ4_0MatmulKernel } @@ -49,7 +52,7 @@ public object NativeKnKernelProvider : KernelProvider { * (re-registering the same instance is a no-op). Call once at startup before any * `ops.matmul` on quantized weights. * - * For quant types without a C kernel (e.g. Q6_K) also register the commonMain + * For quant types without a C kernel also register the commonMain * `ScalarKernelProvider` (from `skainet-backend-cpu`) as the fallback — it lives * in a different module, so the consumer wires it: * `KernelRegistry.register(ScalarKernelProvider)`. @@ -82,6 +85,30 @@ public object NativeKnQ4KMatmulKernel : Q4KMatmulKernel { } } +@OptIn(ExperimentalForeignApi::class) +public object NativeKnQ6KMatmulKernel : Q6KMatmulKernel { + private const val BLOCK_SIZE = 256 + override fun matmul( + input: FloatArray, inputOffset: Int, + weight: ByteArray, weightByteOffset: Int, + inputDim: Int, outputDim: Int, + output: FloatArray, outputOffset: Int, + ) { + require(inputDim % BLOCK_SIZE == 0) { + "NativeKnQ6KMatmulKernel: inputDim must be a multiple of $BLOCK_SIZE; got $inputDim" + } + if (outputDim == 0 || inputDim == 0) return + input.usePinned { i -> weight.usePinned { w -> output.usePinned { o -> + skainet_q6k_matmul( + i.addressOf(0), inputOffset, + w.addressOf(0).reinterpret(), weightByteOffset, + inputDim, outputDim, + o.addressOf(0), outputOffset, + ) + } } } + } +} + @OptIn(ExperimentalForeignApi::class) public object NativeKnQ8_0MatmulKernel : Q8_0MatmulKernel { private const val BLOCK_SIZE = 32