Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import sk.ainet.lang.ops.TensorOp
import sk.ainet.lang.ops.InProgress
import sk.ainet.backend.api.kernel.KernelProvider
import sk.ainet.backend.api.kernel.KernelRegistry
import sk.ainet.lang.tensor.data.RowDequantSource
import sk.ainet.lang.tensor.data.FloatArrayTensorData
import sk.ainet.lang.tensor.data.IntArrayTensorData
import sk.ainet.lang.tensor.data.Q4_0TensorData
Expand Down Expand Up @@ -2633,18 +2634,33 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory
// Preserve index shape + embedding dim
Shape(IntArray(indices.rank) { indices.shape[it] } + intArrayOf(embDim))
}
val outData = dataFactory.init<T, V>(outShape, input.dtype) { outIdx ->
// Map multi-dim output index to flat index and embedding position
fun rowOf(outIdx: IntArray): Int {
// Map multi-dim output index to the flat index into the index list.
val flatIdx = if (outIdx.size == 2) outIdx[0] else {
var flat = 0
for (d in 0 until outIdx.size - 1) {
flat = flat * (if (d < indices.rank) indices.shape[d] else 1) + outIdx[d]
}
flat
}
val row = indexList[flatIdx]
val col = outIdx[outIdx.size - 1]
input.data[row, col]
return indexList[flatIdx]
}
val src = input.data
val outData = if (src is RowDequantSource) {
// Packed / oversized table (e.g. a Q-quantised embedding): dequantise only the rows
// actually touched — never materialise the whole table, never call get() (unsupported on
// such tensors). Each unique row is dequantised once; logical dtype is FP32.
val rowCache = HashMap<Int, FloatArray>()
dataFactory.init<T, V>(outShape, input.dtype) { outIdx ->
val row = rowOf(outIdx)
val col = outIdx[outIdx.size - 1]
@Suppress("UNCHECKED_CAST")
(rowCache.getOrPut(row) { src.dequantRow(row) }[col] as V)
}
} else {
dataFactory.init<T, V>(outShape, input.dtype) { outIdx ->
input.data[rowOf(outIdx), outIdx[outIdx.size - 1]]
}
}
return newTensor(outData, input.dtype, input)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package sk.ainet.exec.tensor.ops

import sk.ainet.context.DirectCpuExecutionContext
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.data.RowDequantSource
import sk.ainet.lang.tensor.data.TensorData
import sk.ainet.lang.types.FP32
import sk.ainet.lang.types.Int32
import kotlin.test.Test
import kotlin.test.assertContentEquals
import kotlin.test.assertEquals

/**
 * `ops.gather` on a [RowDequantSource] table must dequantise only the touched rows — never materialise the
 * whole table and never call `get()` (which such tensors don't support).
 *
 * The fake table below throws from `get`/`set`/`copyToFloatArray` and records every row index handed to
 * [RowDequantSource.dequantRow]. The test therefore passes only if gather went through `dequantRow` AND
 * requested nothing but the rows named by the indices.
 */
class GatherRowDequantTest {

    /**
     * A 4×3 "packed" table: row r dequants to [r*10, r*10+1, r*10+2]. Element access is unsupported.
     * Every [dequantRow] call is logged in [dequantedRows] so the test can inspect which rows were touched.
     */
    private class FakeRowDequantTable : TensorData<FP32, Float>, RowDequantSource {
        /** Row indices passed to [dequantRow], in call order (may contain repeats). */
        val dequantedRows = mutableListOf<Int>()

        override val shape: Shape = Shape(4, 3)

        override fun dequantRow(rowIdx: Int): FloatArray {
            dequantedRows += rowIdx
            return FloatArray(3) { rowIdx * 10f + it }
        }

        override fun get(vararg indices: Int): Float = error("get() must not be called — use dequantRow()")
        override fun set(vararg indices: Int, value: Float) = error("set() unsupported")
        override fun copyToFloatArray(): FloatArray = error("copyToFloatArray() must not be called")
    }

    @Test
    fun gatherDequantsTouchedRowsOnly() {
        val ctx = DirectCpuExecutionContext.create()
        val packed = FakeRowDequantTable()
        val table = ctx.fromData<FP32, Float>(packed, FP32::class)
        val ids = ctx.fromIntArray<Int32, Int>(Shape(3), Int32::class, intArrayOf(2, 0, 3))

        @Suppress("UNCHECKED_CAST")
        val out = ctx.ops.gather(table, ids as Tensor<sk.ainet.lang.types.DType, *>, dim = 0)

        assertEquals(listOf(3, 3), out.shape.dimensions.toList())
        assertContentEquals(
            floatArrayOf(20f, 21f, 22f, /* row 2 */ 0f, 1f, 2f, /* row 0 */ 30f, 31f, 32f /* row 3 */),
            out.data.copyToFloatArray(),
        )
        // Row 1 is not referenced by `ids`, so it must never have been dequantised. Compare as a set:
        // the order and number of calls per row are implementation details of gather.
        assertEquals(
            setOf(0, 2, 3),
            packed.dequantedRows.toSet(),
            "gather must dequantise exactly the rows referenced by the indices",
        )
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package sk.ainet.lang.tensor.data

/**
 * Capability interface for a 2-D [TensorData] whose rows can be **dequantised on demand**, for tables that
 * cannot (or should not) be materialised as a single dense `FloatArray` — e.g. a packed-quant embedding
 * whose logical size exceeds `Int.MAX_VALUE` elements / 2 GB, or one kept packed to save memory.
 *
 * It is not an empty marker: implementers must provide [dequantRow], and row-reading ops detect the
 * capability with an `is RowDequantSource` check on the tensor's data.
 *
 * Such a tensor declares its **logical** dtype `FP32` (the dequantised value type); its packed bytes are an
 * internal storage detail, and `get`/`copyToFloatArray()` are typically unsupported. Ops that read whole
 * rows — primarily **embedding lookup** (`ops.gather` / `ops.indexSelect`, `dim = 0`, indices = token ids)
 * — MUST use [dequantRow] instead of element access, dequantising only the rows actually touched.
 *
 * This is the engine-level home of the contract; model-specific implementations (e.g. a GGUF Q6_K /
 * SafeTensors BF16 embedding) provide [dequantRow] over their own packed source.
 */
public interface RowDequantSource {
/**
 * Dequantise one logical row of the table.
 *
 * Callers index the returned array by column, so it must hold a value for every column of the row.
 *
 * @param rowIdx logical row index, in `0 until shape[0]`. Behaviour for an out-of-range index is not
 *   specified by this contract — NOTE(review): confirm what implementations are expected to do.
 * @return a fresh `FloatArray` of length `shape[1]` holding the dequantised values of row [rowIdx].
 */
public fun dequantRow(rowIdx: Int): FloatArray
}
Loading