From a24f21d0ec18c2a43b22743592bc9101780a6ae6 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Wed, 17 Jun 2026 22:58:15 +0200 Subject: [PATCH] feat(lang/ops): RowDequantSource in the engine + ops.gather row-dequant path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generalizes the per-row dequant trick out of the model layer (SKaiNET-transformers issue #184, hoist 1). Adds `RowDequantSource` (a `TensorData` marker: `dequantRow(rowIdx): FloatArray`) to skainet-lang-core, and teaches `DefaultCpuOps.gather` to use it: when the gathered table implements RowDequantSource, dequantise only the touched rows (each unique row once, cached) instead of the generic element path — which calls `get()`, unsupported on such tensors, and would otherwise force a full FP32 materialisation of the table. A RowDequantSource table declares logical dtype FP32, so gather returns FP32 with no typing change. This lets a packed/oversized embedding (e.g. a Q-quantised token_embd) stay packed and be looked up via ops.gather directly — the basis for keeping Gemma's ~0.67 GB token_embd packed (#178's remaining board-fit item) and, later, whisper int8. Verified: new GatherRowDequantTest — gather over a fake RowDequantSource table whose get()/copyToFloatArray() throw, returns the correct dequantised rows (so it provably went through dequantRow). backend-cpu compiles + the test passes. Next (separate, release-coordinated): SKaiNET-transformers re-points gemma's RowDequantSource to this engine interface (typealias) and routes token_embd through ops.gather; the GemmaQ5KPackedParityTest is the end-to-end gate. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../sk/ainet/exec/tensor/ops/DefaultCpuOps.kt | 26 ++++++++--- .../exec/tensor/ops/GatherRowDequantTest.kt | 45 +++++++++++++++++++ .../lang/tensor/data/RowDequantSource.kt | 19 ++++++++ 3 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/GatherRowDequantTest.kt create mode 100644 skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/RowDequantSource.kt diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt index 10926261..6aad624e 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt @@ -10,6 +10,7 @@ import sk.ainet.lang.ops.TensorOp import sk.ainet.lang.ops.InProgress import sk.ainet.backend.api.kernel.KernelProvider import sk.ainet.backend.api.kernel.KernelRegistry +import sk.ainet.lang.tensor.data.RowDequantSource import sk.ainet.lang.tensor.data.FloatArrayTensorData import sk.ainet.lang.tensor.data.IntArrayTensorData import sk.ainet.lang.tensor.data.Q4_0TensorData @@ -2633,8 +2634,8 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory // Preserve index shape + embedding dim Shape(IntArray(indices.rank) { indices.shape[it] } + intArrayOf(embDim)) } - val outData = dataFactory.init(outShape, input.dtype) { outIdx -> - // Map multi-dim output index to flat index and embedding position + fun rowOf(outIdx: IntArray): Int { + // Map multi-dim output index to the flat index into the index list. 
 val flatIdx = if (outIdx.size == 2) outIdx[0] else { var flat = 0 for (d in 0 until outIdx.size - 1) { @@ -2642,9 +2643,24 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory } flat } - val row = indexList[flatIdx] - val col = outIdx[outIdx.size - 1] - input.data[row, col] + return indexList[flatIdx] + } + val src = input.data + val outData = if (src is RowDequantSource) { + // Packed / oversized table (e.g. a Q-quantised embedding): dequantise only the rows + // actually touched — never materialise the whole table, never call get() (unsupported on + // such tensors). Each unique row is dequantised once; logical dtype is FP32. + val rowCache = HashMap<Int, FloatArray>() + dataFactory.init(outShape, input.dtype) { outIdx -> + val row = rowOf(outIdx) + val col = outIdx[outIdx.size - 1] + @Suppress("UNCHECKED_CAST") + (rowCache.getOrPut(row) { src.dequantRow(row) }[col] as V) + } + } else { + dataFactory.init(outShape, input.dtype) { outIdx -> + input.data[rowOf(outIdx), outIdx[outIdx.size - 1]] + } } return newTensor(outData, input.dtype, input) } diff --git a/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/GatherRowDequantTest.kt b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/GatherRowDequantTest.kt new file mode 100644 index 00000000..80bebefb --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/exec/tensor/ops/GatherRowDequantTest.kt @@ -0,0 +1,45 @@ +package sk.ainet.exec.tensor.ops + +import sk.ainet.context.DirectCpuExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.RowDequantSource +import sk.ainet.lang.tensor.data.TensorData +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int32 +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals + +/** + * `ops.gather` on a [RowDequantSource] table must dequantise only the 
touched rows — never materialise the + * whole table and never call `get()` (which such tensors don't support). The fake table below throws from + * `get`/`set`, so the test passes only if gather went through [RowDequantSource.dequantRow]. + */ +class GatherRowDequantTest { + + /** A 4×3 "packed" table: row r dequants to [r*10, r*10+1, r*10+2]. Element access is unsupported. */ + private class FakeRowDequantTable : TensorData, RowDequantSource { + override val shape: Shape = Shape(4, 3) + override fun dequantRow(rowIdx: Int): FloatArray = FloatArray(3) { rowIdx * 10f + it } + override fun get(vararg indices: Int): Float = error("get() must not be called — use dequantRow()") + override fun set(vararg indices: Int, value: Float) = error("set() unsupported") + override fun copyToFloatArray(): FloatArray = error("copyToFloatArray() must not be called") + } + + @Test + fun gatherDequantsTouchedRowsOnly() { + val ctx = DirectCpuExecutionContext.create() + val table = ctx.fromData(FakeRowDequantTable(), FP32::class) + val ids = ctx.fromIntArray(Shape(3), Int32::class, intArrayOf(2, 0, 3)) + + @Suppress("UNCHECKED_CAST") + val out = ctx.ops.gather(table, ids as Tensor, dim = 0) + + assertEquals(listOf(3, 3), out.shape.dimensions.toList()) + assertContentEquals( + floatArrayOf(20f, 21f, 22f, /* row 2 */ 0f, 1f, 2f, /* row 0 */ 30f, 31f, 32f /* row 3 */), + out.data.copyToFloatArray(), + ) + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/RowDequantSource.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/RowDequantSource.kt new file mode 100644 index 00000000..cb3af873 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/data/RowDequantSource.kt @@ -0,0 +1,19 @@ +package sk.ainet.lang.tensor.data + +/** + * Marker for a 2-D [TensorData] whose rows can be **dequantised on demand**, for tables that cannot (or + * should not) be materialised as a single 
dense `FloatArray` — e.g. a packed-quant embedding whose logical + * size exceeds `Int.MAX_VALUE` elements / 2 GB, or one kept packed to save memory. + * + * Such a tensor declares its **logical** dtype `FP32` (the dequantised value type); its packed bytes are an + * internal storage detail, and `get`/`copyToFloatArray()` are typically unsupported. Ops that read whole + * rows — primarily **embedding lookup** (`ops.gather` / `ops.indexSelect`, `dim = 0`, indices = token ids) + * — MUST use [dequantRow] instead of element access, dequantising only the rows actually touched. + * + * This is the engine-level home of the contract; model-specific implementations (e.g. a GGUF Q6_K / + * SafeTensors BF16 embedding) provide [dequantRow] over their own packed source. + */ +public interface RowDequantSource { + /** Dequantise logical row [rowIdx] (`0 until shape[0]`) to a fresh `FloatArray` of length `shape[1]`. */ + public fun dequantRow(rowIdx: Int): FloatArray +}