diff --git a/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/GruTest.kt b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/GruTest.kt new file mode 100644 index 00000000..33d18397 --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/GruTest.kt @@ -0,0 +1,87 @@ +package sk.ainet.sk.ainet.exec.tensor.ops + +import kotlin.math.exp +import kotlin.math.tanh +import kotlin.test.Test +import kotlin.test.assertEquals +import sk.ainet.context.DirectCpuExecutionContext +import sk.ainet.lang.nn.Gru +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.types.FP32 + +/** + * Eager forward correctness for the GRU layer, checked against an independent + * scalar reference (raw FloatArray loops) computing the same PyTorch r,z,n + * recurrence. Exercises all three gates, the per-timestep hidden feedback + * (weight_hh path) and the update blend. + */ +class GruTest { + private val ctx = DirectCpuExecutionContext() + + private fun sigmoid(x: Float): Float = (1.0 / (1.0 + exp(-x.toDouble()))).toFloat() + + // Independent reference: input [S*D] (batch=1), weights matmul-ready (row-major). + private fun gruRef( + x: FloatArray, seq: Int, d: Int, h: Int, + wIh: FloatArray, wHh: FloatArray, bIh: FloatArray, bHh: FloatArray, + ): FloatArray { + val g = 3 * h + val hidden = FloatArray(h) + val out = FloatArray(seq * h) + for (t in 0 until seq) { + val gx = FloatArray(g) { k -> bIh[k] + (0 until d).sumOf { i -> (x[t * d + i] * wIh[i * g + k]).toDouble() }.toFloat() } + val gh = FloatArray(g) { k -> bHh[k] + (0 until h).sumOf { j -> (hidden[j] * wHh[j * g + k]).toDouble() }.toFloat() } + for (j in 0 until h) { + val r = sigmoid(gx[j] + gh[j]) + val z = sigmoid(gx[h + j] + gh[h + j]) + val n = tanh((gx[2 * h + j] + r * gh[2 * h + j]).toDouble()).toFloat() + hidden[j] = (1f - z) * n + z * hidden[j] + } + for (j in 0 until h) out[t * h + j] = hidden[j] + } + return out + } + + @Test + fun gru_forward_matches_reference() { + val batch = 1; val seq = 2; val d = 2; val h = 2; val g = 3 * h + // deterministic small weights/inputs (kept in sigmoid/tanh's sensitive range) + val x = FloatArray(seq * d) { ((it % 5) - 2) * 0.3f } + val wIh = FloatArray(d * g) { ((it % 7) - 3) * 0.1f } + val wHh = FloatArray(h * g) { ((it % 5) - 2) * 0.15f } + val bIh = FloatArray(g) { ((it % 3) - 1) * 0.2f } + val bHh = FloatArray(g) { ((it % 4) - 2) * 0.1f } + + val gru = Gru( + inputSize = d, hiddenSize = h, name = "gru", + initWeightIh = ctx.fromFloatArray(Shape(d, g), FP32::class, wIh), + initWeightHh = ctx.fromFloatArray(Shape(h, g), FP32::class, wHh), + initBiasIh = ctx.fromFloatArray(Shape(g), FP32::class, bIh), + initBiasHh = ctx.fromFloatArray(Shape(g), FP32::class, bHh), + ) + + val input = ctx.fromFloatArray(Shape(batch, seq, d), FP32::class, x) + val out = gru.forward(input, ctx) + assertEquals(Shape(batch, seq, h), out.shape) + + val expected = gruRef(x, seq, d, h, wIh, wHh, bIh, bHh) + for (t in 0 until seq) for (j in 0 until h) { + assertEquals(expected[t * h + j], out.data[0, t, j], 1e-5f) + } + } + + @Test + fun gru_output_shape_is_batch_seq_hidden() { + val batch = 2; val seq = 3; val d = 4; val h = 5; val g = 3 * h + val gru = Gru( + inputSize = d, hiddenSize = h, name = "gru", + initWeightIh = ctx.fromFloatArray(Shape(d, g), FP32::class, FloatArray(d * g) { (it % 9 - 4) * 0.05f }), + initWeightHh = ctx.fromFloatArray(Shape(h, g), FP32::class, FloatArray(h * g) { (it % 7 - 3) * 0.05f }), + initBiasIh = ctx.fromFloatArray(Shape(g), FP32::class, FloatArray(g) { 0f }), + initBiasHh = ctx.fromFloatArray(Shape(g), FP32::class, FloatArray(g) { 0f }), + ) + val input = ctx.fromFloatArray(Shape(batch, seq, d), FP32::class, FloatArray(batch * seq * d) { (it % 11 - 5) * 0.1f }) + val out = gru.forward(input, ctx) + assertEquals(Shape(batch, seq, h), out.shape) + } +} diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt index f8d0d00f..36e29360 100644 --- a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt @@ -13,6 +13,7 @@ import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.tensor.data.DenseTensorDataFactory import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.nn.Gru import sk.ainet.lang.tensor.ops.UpsampleMode import sk.ainet.lang.tensor.withRequiresGrad import sk.ainet.lang.trace.GraphSink @@ -232,4 +233,28 @@ class ConvPoolBackwardTest { x.ops.upsample2d(x, scale = 2 to 2, mode = UpsampleMode.Bilinear, alignCorners = false) } } + + @Test + fun gru_backward_input_matches_finite_diff() { + // Single-layer GRU, input [batch=1, seq=2, in=2], hidden=2. The composed/unrolled + // cell must propagate gradients through every gate back to the input. + val d = 2; val h = 2; val g = 3 * h + val wIh = FloatArray(d * g) { ((it % 7) - 3) * 0.1f } + val wHh = FloatArray(h * g) { ((it % 5) - 2) * 0.1f } + val bIh = FloatArray(g) { ((it % 3) - 1) * 0.1f } + val bHh = FloatArray(g) { ((it % 4) - 2) * 0.05f } + assertGradMatchesFiniteDiff( + xShape = Shape(1, 2, d), + x0 = floatArrayOf(0.3f, -0.2f, 0.1f, 0.4f), + ) { c, x -> + val gru = Gru( + inputSize = d, hiddenSize = h, name = "gru", + initWeightIh = floatTensor(c, Shape(d, g), wIh), + initWeightHh = floatTensor(c, Shape(h, g), wHh), + initBiasIh = floatTensor(c, Shape(g), bIh), + initBiasHh = floatTensor(c, Shape(g), bHh), + ) + gru.forward(x, c) + } + } } diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/Gru.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/Gru.kt new file mode 100644 index 00000000..1211398b --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/Gru.kt @@ -0,0 +1,118 @@ +package sk.ainet.lang.nn + +import sk.ainet.context.ExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.types.DType +import sk.ainet.lang.nn.topology.ModuleParameter +import sk.ainet.lang.nn.topology.ModuleParameters + +/** + * Single-layer, unidirectional, batch-first GRU (gated recurrent unit). + * + * Input `[batch, seq, inputSize]` -> output `[batch, seq, hiddenSize]` (all hidden states). + * + * The recurrence is **unrolled over the (static) sequence length at trace time** and built + * entirely from existing primitive ops (matmul / add / sigmoid / tanh / multiply / narrow / + * concat), so it runs in the eager engine, is trainable through the standard autodiff tape, + * and traces to StableHLO with no dedicated converter (StableHLO has no loop construct). + * + * Gate math matches `torch.nn.GRU` (gate order reset, update, new), so PyTorch weights load + * directly (after transposing to the matmul-ready orientation used here): + * + * r = sigmoid(x·W_ir + b_ir + h·W_hr + b_hr) + * z = sigmoid(x·W_iz + b_iz + h·W_hz + b_hz) + * n = tanh (x·W_in + b_in + r ⊙ (h·W_hn + b_hn)) + * h' = (1 - z) ⊙ n + z ⊙ h + * + * Weights are stored **matmul-ready** (input-major) — the three gates are concatenated on the + * trailing axis so a single matmul produces all three pre-activations: + * - [weightIh] `[inputSize, 3*hiddenSize]`, [weightHh] `[hiddenSize, 3*hiddenSize]` + * - [biasIh] / [biasHh] `[3*hiddenSize]` (gate order r, z, n) + * + * @param inputSize number of input features + * @param hiddenSize size of the hidden state + */ +public class Gru @kotlin.jvm.JvmOverloads constructor( + public val inputSize: Int, + public val hiddenSize: Int, + override val name: String = "Gru", + initWeightIh: Tensor, + initWeightHh: Tensor, + initBiasIh: Tensor, + initBiasHh: Tensor, + public val trainable: Boolean = true, +) : Module(), ModuleParameters { + + init { + require(inputSize > 0) { "Gru($name): inputSize must be positive, was $inputSize" } + require(hiddenSize > 0) { "Gru($name): hiddenSize must be positive, was $hiddenSize" } + val g = 3 * hiddenSize + fun check2d(t: Tensor, rows: Int, cols: Int, what: String) { + val s = t.shape.dimensions + require(t.rank == 2 && s[0] == rows && s[1] == cols) { + "Gru($name): $what shape must be [$rows, $cols], but was ${t.shape}" + } + } + check2d(initWeightIh, inputSize, g, "weightIh") + check2d(initWeightHh, hiddenSize, g, "weightHh") + fun check1d(t: Tensor, len: Int, what: String) { + require(t.rank == 1 && t.shape.dimensions[0] == len) { + "Gru($name): $what shape must be [$len], but was ${t.shape}" + } + } + check1d(initBiasIh, g, "biasIh") + check1d(initBiasHh, g, "biasHh") + } + + private val pWeightIh = ModuleParameter.WeightParameter("$name.weight_ih", initWeightIh, trainable) + private val pWeightHh = ModuleParameter.WeightParameter("$name.weight_hh", initWeightHh, trainable) + private val pBiasIh = ModuleParameter.BiasParameter("$name.bias_ih", initBiasIh, trainable) + private val pBiasHh = ModuleParameter.BiasParameter("$name.bias_hh", initBiasHh, trainable) + + override val params: List> = listOf(pWeightIh, pWeightHh, pBiasIh, pBiasHh) + + override val modules: List> + get() = emptyList() + + override fun onForward(input: Tensor, ctx: ExecutionContext): Tensor { + require(input.rank == 3) { + "Gru($name): input must be 3D [batch, seq, inputSize], but was ${input.shape}" + } + val ops = ctx.ops + val batch = input.shape[0] + val seq = input.shape[1] + val h = hiddenSize + + val weightIh = pWeightIh.value + val weightHh = pWeightHh.value + val biasIh = pBiasIh.value + val biasHh = pBiasHh.value + + // Initial hidden state h0 = 0 (a constant leaf in the trace). + var hidden = ctx.zeros(Shape(batch, h), input.dtype) + + val outputs = ArrayList>(seq) + for (t in 0 until seq) { + // x_t : [batch, inputSize] + val xt = ops.reshape(ops.narrow(input, 1, t, 1), Shape(batch, inputSize)) + // gate pre-activations: [batch, 3H] + val gx = ops.add(ops.matmul(xt, weightIh), biasIh) + val gh = ops.add(ops.matmul(hidden, weightHh), biasHh) + // split into reset / update / new on the gate axis + val xr = ops.narrow(gx, 1, 0, h); val hr = ops.narrow(gh, 1, 0, h) + val xz = ops.narrow(gx, 1, h, h); val hz = ops.narrow(gh, 1, h, h) + val xn = ops.narrow(gx, 1, 2 * h, h); val hn = ops.narrow(gh, 1, 2 * h, h) + + val r = ops.sigmoid(ops.add(xr, hr)) + val z = ops.sigmoid(ops.add(xz, hz)) + val n = ops.tanh(ops.add(xn, ops.multiply(r, hn))) + // h' = (1 - z) ⊙ n + z ⊙ h + val oneMinusZ = ops.rsubScalar(1.0, z) + hidden = ops.add(ops.multiply(oneMinusZ, n), ops.multiply(z, hidden)) + + outputs.add(ops.unsqueeze(hidden, 1)) // [batch, 1, H] + } + return ops.concat(outputs, 1) // [batch, seq, H] + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt index 4605d2af..bab5b93e 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt @@ -6,6 +6,7 @@ import sk.ainet.lang.nn.Conv1d import sk.ainet.lang.nn.Conv2d import sk.ainet.lang.nn.Conv3d import sk.ainet.lang.nn.Flatten +import sk.ainet.lang.nn.Gru import sk.ainet.lang.nn.Input import sk.ainet.lang.nn.Linear import sk.ainet.lang.nn.MaxPool2d @@ -153,6 +154,17 @@ public interface NeuralNetworkDsl : NetworkDslItem { */ public fun dense(id: String = "", content: DENSE.() -> Unit = {}) + /** + * Creates a single-layer unidirectional GRU. Input `[batch, seq, features]` -> + * output `[batch, seq, hiddenSize]`. Input feature size is inferred from the + * preceding layer's output dimension. + * + * @param hiddenSize size of the GRU hidden state + * @param id Optional identifier for the layer + * @param content Configuration block (e.g. `trainable`) + */ + public fun gru(hiddenSize: Int, id: String = "", content: GRU.() -> Unit = {}) + /** * Creates a dense layer with precision override and specified output dimension. * This allows individual layers to use different precision than the network default. @@ -598,6 +610,37 @@ public class FlattenImpl( } } +/** DSL config handle for a [Gru] layer. */ +public interface GRU : NetworkDslItem { + public var trainable: Boolean +} + +public class GruImpl( + override val executionContext: ExecutionContext, + private val inputSize: Int, + private val hiddenSize: Int, + private val id: String, + private val kClass: KClass, +) : GRU { + override var trainable: Boolean = true + + public fun create(): Gru { + require(inputSize > 0) { "Gru inputSize must be > 0 (declare an input shape before gru)." } + require(hiddenSize > 0) { "Gru hiddenSize must be > 0." } + val g = 3 * hiddenSize + return Gru( + inputSize = inputSize, + hiddenSize = hiddenSize, + name = id, + initWeightIh = executionContext.placeholder(Shape(inputSize, g), kClass), + initWeightHh = executionContext.placeholder(Shape(hiddenSize, g), kClass), + initBiasIh = executionContext.placeholder(Shape(g), kClass), + initBiasHh = executionContext.placeholder(Shape(g), kClass), + trainable = trainable, + ) + } +} + private fun createLinear( executionContext: ExecutionContext, inFeatures: Int, @@ -1237,6 +1280,21 @@ public class StageImpl( modules += impl.create() } + override fun gru(hiddenSize: Int, id: String, content: GRU.() -> Unit) { + val inputSize = lastDimension + lastDimension = hiddenSize + currentShape = intArrayOf(hiddenSize) + val impl = GruImpl( + executionContext, + inputSize = inputSize, + hiddenSize = hiddenSize, + id = getDefaultName(id, "gru", modules.size), + kClass = kClass, + ) + impl.content() + modules += impl.create() + } + override fun activation(id: String, activation: (Tensor) -> Tensor) { modules += ActivationsWrapperModule(activation, getDefaultName(id, "activation", modules.size)) } @@ -1616,6 +1674,21 @@ public class NeuralNetworkDslImpl( modules += impl.create() } + override fun gru(hiddenSize: Int, id: String, content: GRU.() -> Unit) { + val inputSize = lastDimension + lastDimension = hiddenSize + currentShape = intArrayOf(hiddenSize) + val impl = GruImpl( + executionContext, + inputSize = inputSize, + hiddenSize = hiddenSize, + id = getDefaultName(id, "gru", modules.size), + kClass = kClass, + ) + impl.content() + modules += impl.create() + } + override fun activation(id: String, activation: (Tensor) -> Tensor) { modules += ActivationsWrapperModule(activation, getDefaultName(id, "activation", modules.size)) } diff --git a/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/GruDslTest.kt b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/GruDslTest.kt new file mode 100644 index 00000000..c415b971 --- /dev/null +++ b/skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/GruDslTest.kt @@ -0,0 +1,28 @@ +package sk.ainet.lang.nn + +import sk.ainet.lang.types.FP32 +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +/** The `gru(...)` network-DSL builder wires a [Gru] layer into the model. */ +class GruDslTest { + + @Test + fun gru_dsl_builds_a_gru_layer() { + val model = definition { + network { + input(4) // input feature size D + gru(8) // hidden size H + } + } + assertNotNull(model) + // The built model tree must contain a Gru module configured D=4 -> H=8. + val grus = flattenModules(model).filterIsInstance>() + assertTrue(grus.isNotEmpty(), "network { gru(8) } must produce a Gru module") + assertTrue(grus.any { it.inputSize == 4 && it.hiddenSize == 8 }) + } + + private fun flattenModules(m: Module): List> = + listOf(m) + m.modules.flatMap { flattenModules(it) } +}