Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package sk.ainet.sk.ainet.exec.tensor.ops

import kotlin.math.exp
import kotlin.math.tanh
import kotlin.test.Test
import kotlin.test.assertEquals
import sk.ainet.context.DirectCpuExecutionContext
import sk.ainet.lang.nn.Gru
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.types.FP32

/**
 * Forward-pass checks for [Gru] on the eager CPU context.
 *
 * The layer output is compared against a plain-`FloatArray` re-implementation of the
 * PyTorch-style r/z/n recurrence, so the reset, update and candidate gates, the recurrent
 * (weight_hh) feedback between timesteps and the final blend are all covered.
 */
class GruTest {
    private val ctx = DirectCpuExecutionContext()

    /** Logistic function evaluated in double precision and narrowed back to `Float`. */
    private fun sigmoid(x: Float): Float {
        val denominator = 1.0 + exp(-x.toDouble())
        return (1.0 / denominator).toFloat()
    }

    /**
     * Scalar reference GRU for a single batch element.
     *
     * [x] is the flattened `[seq, d]` input; [wIh] (`[d, 3h]`) and [wHh] (`[h, 3h]`) are
     * row-major with the gates laid out r, z, n along the trailing axis; [bIh] / [bHh] hold
     * `3h` entries each. Returns every hidden state, flattened as `[seq, h]`.
     */
    private fun gruRef(
        x: FloatArray, seq: Int, d: Int, h: Int,
        wIh: FloatArray, wHh: FloatArray, bIh: FloatArray, bHh: FloatArray,
    ): FloatArray {
        val gateWidth = 3 * h
        val state = FloatArray(h)
        val result = FloatArray(seq * h)

        // bias[col] + sum over rows of vec[offset + row] * weights[row, col]. Each product is
        // formed in Float, accumulated in Double, and narrowed before the bias is added.
        fun project(vec: FloatArray, offset: Int, rows: Int, weights: FloatArray, bias: FloatArray): FloatArray {
            val projected = FloatArray(gateWidth)
            for (col in 0 until gateWidth) {
                var acc = 0.0
                for (row in 0 until rows) {
                    acc += (vec[offset + row] * weights[row * gateWidth + col]).toDouble()
                }
                projected[col] = bias[col] + acc.toFloat()
            }
            return projected
        }

        for (step in 0 until seq) {
            val fromInput = project(x, step * d, d, wIh, bIh)
            // Uses the previous step's state; it is only overwritten in the loop below.
            val fromHidden = project(state, 0, h, wHh, bHh)
            for (cell in 0 until h) {
                val reset = sigmoid(fromInput[cell] + fromHidden[cell])
                val update = sigmoid(fromInput[h + cell] + fromHidden[h + cell])
                val candidate =
                    tanh((fromInput[2 * h + cell] + reset * fromHidden[2 * h + cell]).toDouble()).toFloat()
                state[cell] = (1f - update) * candidate + update * state[cell]
            }
            // Record this timestep's hidden state in the flattened output.
            state.copyInto(result, destinationOffset = step * h)
        }
        return result
    }

    @Test
    fun gru_forward_matches_reference() {
        val batch = 1
        val seq = 2
        val d = 2
        val h = 2
        val g = 3 * h

        // Small deterministic values that keep sigmoid/tanh away from saturation.
        val x = FloatArray(seq * d) { i -> ((i % 5) - 2) * 0.3f }
        val wIh = FloatArray(d * g) { i -> ((i % 7) - 3) * 0.1f }
        val wHh = FloatArray(h * g) { i -> ((i % 5) - 2) * 0.15f }
        val bIh = FloatArray(g) { i -> ((i % 3) - 1) * 0.2f }
        val bHh = FloatArray(g) { i -> ((i % 4) - 2) * 0.1f }

        val layer = Gru<FP32, Float>(
            inputSize = d,
            hiddenSize = h,
            name = "gru",
            initWeightIh = ctx.fromFloatArray(Shape(d, g), FP32::class, wIh),
            initWeightHh = ctx.fromFloatArray(Shape(h, g), FP32::class, wHh),
            initBiasIh = ctx.fromFloatArray(Shape(g), FP32::class, bIh),
            initBiasHh = ctx.fromFloatArray(Shape(g), FP32::class, bHh),
        )

        val input = ctx.fromFloatArray<FP32, Float>(Shape(batch, seq, d), FP32::class, x)
        val actual = layer.forward(input, ctx)
        assertEquals(Shape(batch, seq, h), actual.shape)

        val expected = gruRef(x, seq, d, h, wIh, wHh, bIh, bHh)
        // Walk the flattened [seq, h] reference and map each position back to (t, j).
        for (flat in expected.indices) {
            val t = flat / h
            val j = flat % h
            assertEquals(expected[flat], actual.data[0, t, j], 1e-5f)
        }
    }

    @Test
    fun gru_output_shape_is_batch_seq_hidden() {
        val batch = 2
        val seq = 3
        val d = 4
        val h = 5
        val g = 3 * h

        val layer = Gru<FP32, Float>(
            inputSize = d,
            hiddenSize = h,
            name = "gru",
            initWeightIh = ctx.fromFloatArray(Shape(d, g), FP32::class, FloatArray(d * g) { i -> (i % 9 - 4) * 0.05f }),
            initWeightHh = ctx.fromFloatArray(Shape(h, g), FP32::class, FloatArray(h * g) { i -> (i % 7 - 3) * 0.05f }),
            initBiasIh = ctx.fromFloatArray(Shape(g), FP32::class, FloatArray(g) { 0f }),
            initBiasHh = ctx.fromFloatArray(Shape(g), FP32::class, FloatArray(g) { 0f }),
        )

        val values = FloatArray(batch * seq * d) { i -> (i % 11 - 5) * 0.1f }
        val input = ctx.fromFloatArray<FP32, Float>(Shape(batch, seq, d), FP32::class, values)
        val actual = layer.forward(input, ctx)
        assertEquals(Shape(batch, seq, h), actual.shape)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
import sk.ainet.lang.tensor.data.FloatArrayTensorData
import sk.ainet.lang.nn.Gru
import sk.ainet.lang.tensor.ops.UpsampleMode
import sk.ainet.lang.tensor.withRequiresGrad
import sk.ainet.lang.trace.GraphSink
Expand Down Expand Up @@ -232,4 +233,28 @@ class ConvPoolBackwardTest {
x.ops.upsample2d(x, scale = 2 to 2, mode = UpsampleMode.Bilinear, alignCorners = false)
}
}

@Test
fun gru_backward_input_matches_finite_diff() {
    // One GRU layer over an input of shape [batch=1, seq=2, in=2] with hidden=2. The cell is
    // composed from primitive ops and unrolled, so the input gradient has to flow back
    // through all three gates.
    val inFeatures = 2
    val hiddenUnits = 2
    val gateWidth = 3 * hiddenUnits

    val weightsIh = FloatArray(inFeatures * gateWidth) { i -> ((i % 7) - 3) * 0.1f }
    val weightsHh = FloatArray(hiddenUnits * gateWidth) { i -> ((i % 5) - 2) * 0.1f }
    val biasesIh = FloatArray(gateWidth) { i -> ((i % 3) - 1) * 0.1f }
    val biasesHh = FloatArray(gateWidth) { i -> ((i % 4) - 2) * 0.05f }

    assertGradMatchesFiniteDiff(
        xShape = Shape(1, 2, inFeatures),
        x0 = floatArrayOf(0.3f, -0.2f, 0.1f, 0.4f),
    ) { execCtx, input ->
        // A fresh layer is built on every invocation from the same fixed parameters.
        Gru<FP32, Float>(
            inputSize = inFeatures,
            hiddenSize = hiddenUnits,
            name = "gru",
            initWeightIh = floatTensor(execCtx, Shape(inFeatures, gateWidth), weightsIh),
            initWeightHh = floatTensor(execCtx, Shape(hiddenUnits, gateWidth), weightsHh),
            initBiasIh = floatTensor(execCtx, Shape(gateWidth), biasesIh),
            initBiasHh = floatTensor(execCtx, Shape(gateWidth), biasesHh),
        ).forward(input, execCtx)
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package sk.ainet.lang.nn

import sk.ainet.context.ExecutionContext
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.types.DType
import sk.ainet.lang.nn.topology.ModuleParameter
import sk.ainet.lang.nn.topology.ModuleParameters

/**
 * Single-layer, unidirectional, batch-first GRU (gated recurrent unit).
 *
 * Input `[batch, seq, inputSize]` -> output `[batch, seq, hiddenSize]` (all hidden states).
 *
 * The recurrence is **unrolled over the (static) sequence length at trace time** and built
 * entirely from existing primitive ops (matmul / add / sigmoid / tanh / multiply / narrow /
 * concat), so it runs in the eager engine, is trainable through the standard autodiff tape,
 * and traces to StableHLO with no dedicated converter (StableHLO has no loop construct).
 *
 * Gate math matches `torch.nn.GRU` (gate order reset, update, new), so PyTorch weights load
 * directly (after transposing to the matmul-ready orientation used here):
 *
 *     r  = sigmoid(x·W_ir + b_ir + h·W_hr + b_hr)
 *     z  = sigmoid(x·W_iz + b_iz + h·W_hz + b_hz)
 *     n  = tanh   (x·W_in + b_in + r ⊙ (h·W_hn + b_hn))
 *     h' = (1 - z) ⊙ n + z ⊙ h
 *
 * Weights are stored **matmul-ready** (input-major) — the three gates are concatenated on the
 * trailing axis so a single matmul produces all three pre-activations:
 *  - `weight_ih` `[inputSize, 3*hiddenSize]`, `weight_hh` `[hiddenSize, 3*hiddenSize]`
 *  - `bias_ih` / `bias_hh` `[3*hiddenSize]` (gate order r, z, n)
 *
 * The initial hidden state is always zero.
 *
 * @param inputSize number of input features; must be positive
 * @param hiddenSize size of the hidden state; must be positive
 * @param name module name, also used as the prefix of the parameter names
 * @param initWeightIh initial input-to-hidden weights, shape `[inputSize, 3*hiddenSize]`
 * @param initWeightHh initial hidden-to-hidden weights, shape `[hiddenSize, 3*hiddenSize]`
 * @param initBiasIh initial input-side bias, shape `[3*hiddenSize]`
 * @param initBiasHh initial hidden-side bias, shape `[3*hiddenSize]`
 * @param trainable flag handed to each of the four parameters
 * @throws IllegalArgumentException if a size is not positive or an initial tensor has the
 *   wrong shape; the forward pass throws it when the input is not a non-empty
 *   `[batch, seq, inputSize]` tensor
 */
public class Gru<T : DType, V> @kotlin.jvm.JvmOverloads constructor(
    public val inputSize: Int,
    public val hiddenSize: Int,
    override val name: String = "Gru",
    initWeightIh: Tensor<T, V>,
    initWeightHh: Tensor<T, V>,
    initBiasIh: Tensor<T, V>,
    initBiasHh: Tensor<T, V>,
    public val trainable: Boolean = true,
) : Module<T, V>(), ModuleParameters<T, V> {

    init {
        require(inputSize > 0) { "Gru($name): inputSize must be positive, was $inputSize" }
        require(hiddenSize > 0) { "Gru($name): hiddenSize must be positive, was $hiddenSize" }
        // All three gates share one matrix/bias, so the trailing axis is 3 * hiddenSize wide.
        val g = 3 * hiddenSize
        fun check2d(t: Tensor<T, V>, rows: Int, cols: Int, what: String) {
            val s = t.shape.dimensions
            require(t.rank == 2 && s[0] == rows && s[1] == cols) {
                "Gru($name): $what shape must be [$rows, $cols], but was ${t.shape}"
            }
        }
        check2d(initWeightIh, inputSize, g, "weightIh")
        check2d(initWeightHh, hiddenSize, g, "weightHh")
        fun check1d(t: Tensor<T, V>, len: Int, what: String) {
            require(t.rank == 1 && t.shape.dimensions[0] == len) {
                "Gru($name): $what shape must be [$len], but was ${t.shape}"
            }
        }
        check1d(initBiasIh, g, "biasIh")
        check1d(initBiasHh, g, "biasHh")
    }

    private val pWeightIh = ModuleParameter.WeightParameter("$name.weight_ih", initWeightIh, trainable)
    private val pWeightHh = ModuleParameter.WeightParameter("$name.weight_hh", initWeightHh, trainable)
    private val pBiasIh = ModuleParameter.BiasParameter("$name.bias_ih", initBiasIh, trainable)
    private val pBiasHh = ModuleParameter.BiasParameter("$name.bias_hh", initBiasHh, trainable)

    override val params: List<ModuleParameter<T, V>> = listOf(pWeightIh, pWeightHh, pBiasIh, pBiasHh)

    override val modules: List<Module<T, V>>
        get() = emptyList()

    /**
     * Runs the unrolled recurrence over every timestep of [input] and returns all hidden
     * states stacked along the sequence axis.
     *
     * @param input tensor of shape `[batch, seq, inputSize]` with `seq > 0`
     * @param ctx execution context supplying the tensor ops and the zero initial state
     * @return tensor of shape `[batch, seq, hiddenSize]`
     * @throws IllegalArgumentException if [input] is not 3D, its trailing dimension differs
     *   from [inputSize], or its sequence length is not positive
     */
    override fun onForward(input: Tensor<T, V>, ctx: ExecutionContext): Tensor<T, V> {
        require(input.rank == 3) {
            "Gru($name): input must be 3D [batch, seq, inputSize], but was ${input.shape}"
        }
        val batch = input.shape[0]
        val seq = input.shape[1]
        // Fail here with a descriptive message instead of inside the per-step reshape below.
        require(input.shape[2] == inputSize) {
            "Gru($name): input feature size must be $inputSize, but was ${input.shape[2]} (input shape ${input.shape})"
        }
        // With no timesteps the loop yields nothing and there would be nothing to concatenate.
        require(seq > 0) {
            "Gru($name): sequence length must be positive, but input shape was ${input.shape}"
        }

        val ops = ctx.ops
        val h = hiddenSize

        val weightIh = pWeightIh.value
        val weightHh = pWeightHh.value
        val biasIh = pBiasIh.value
        val biasHh = pBiasHh.value

        // Initial hidden state h0 = 0 (a constant leaf in the trace).
        var hidden = ctx.zeros<T, V>(Shape(batch, h), input.dtype)

        val outputs = ArrayList<Tensor<T, V>>(seq)
        for (t in 0 until seq) {
            // x_t : [batch, inputSize]
            val xt = ops.reshape(ops.narrow(input, 1, t, 1), Shape(batch, inputSize))
            // gate pre-activations: [batch, 3H]
            val gx = ops.add(ops.matmul(xt, weightIh), biasIh)
            val gh = ops.add(ops.matmul(hidden, weightHh), biasHh)
            // split into reset / update / new on the gate axis
            val xr = ops.narrow(gx, 1, 0, h)
            val hr = ops.narrow(gh, 1, 0, h)
            val xz = ops.narrow(gx, 1, h, h)
            val hz = ops.narrow(gh, 1, h, h)
            val xn = ops.narrow(gx, 1, 2 * h, h)
            val hn = ops.narrow(gh, 1, 2 * h, h)

            val r = ops.sigmoid(ops.add(xr, hr))
            val z = ops.sigmoid(ops.add(xz, hz))
            // The reset gate scales only the hidden-side contribution of the candidate.
            val n = ops.tanh(ops.add(xn, ops.multiply(r, hn)))
            // h' = (1 - z) ⊙ n + z ⊙ h
            val oneMinusZ = ops.rsubScalar(1.0, z)
            hidden = ops.add(ops.multiply(oneMinusZ, n), ops.multiply(z, hidden))

            outputs.add(ops.unsqueeze(hidden, 1)) // [batch, 1, H]
        }
        return ops.concat(outputs, 1) // [batch, seq, H]
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import sk.ainet.lang.nn.Conv1d
import sk.ainet.lang.nn.Conv2d
import sk.ainet.lang.nn.Conv3d
import sk.ainet.lang.nn.Flatten
import sk.ainet.lang.nn.Gru
import sk.ainet.lang.nn.Input
import sk.ainet.lang.nn.Linear
import sk.ainet.lang.nn.MaxPool2d
Expand Down Expand Up @@ -153,6 +154,17 @@ public interface NeuralNetworkDsl<T : DType, V> : NetworkDslItem {
*/
public fun dense(id: String = "", content: DENSE<T, V>.() -> Unit = {})

/**
 * Adds a single-layer, unidirectional GRU to the network.
 *
 * Maps input `[batch, seq, features]` to output `[batch, seq, hiddenSize]`. The input
 * feature size is not passed here: the builder implementations in this file take it from
 * the output dimension tracked for the preceding layer, so an input shape has to be
 * declared before `gru` is called.
 *
 * @param hiddenSize size of the GRU hidden state; must be positive
 * @param id Optional identifier for the layer
 * @param content Configuration block run on the layer's [GRU] handle (e.g. `trainable`)
 */
public fun gru(hiddenSize: Int, id: String = "", content: GRU<T, V>.() -> Unit = {})

/**
* Creates a dense layer with precision override and specified output dimension.
* This allows individual layers to use different precision than the network default.
Expand Down Expand Up @@ -598,6 +610,37 @@ public class FlattenImpl<T : DType, V>(
}
}

/**
 * DSL config handle for a [Gru] layer; it is the receiver of the configuration block
 * passed to `gru(...)`.
 */
public interface GRU<T : DType, V> : NetworkDslItem {
/** Value passed on as the `trainable` flag of the [Gru] that gets built. */
public var trainable: Boolean
}

/**
 * Builder behind the `gru(...)` DSL call: holds the layer dimensions and the
 * user-adjustable [trainable] flag, and materialises the [Gru] module in [create].
 */
public class GruImpl<T : DType, V>(
    override val executionContext: ExecutionContext,
    private val inputSize: Int,
    private val hiddenSize: Int,
    private val id: String,
    private val kClass: KClass<T>,
) : GRU<T, V> {
    override var trainable: Boolean = true

    /**
     * Validates both sizes and builds the layer, backing each of its four parameters with a
     * placeholder tensor of the expected shape obtained from [executionContext].
     */
    public fun create(): Gru<T, V> {
        require(inputSize > 0) { "Gru inputSize must be > 0 (declare an input shape before gru)." }
        require(hiddenSize > 0) { "Gru hiddenSize must be > 0." }

        // The three gates share one matrix/bias, hence three hidden-sized column groups.
        val gateWidth = 3 * hiddenSize
        val weightIh: Tensor<T, V> = executionContext.placeholder(Shape(inputSize, gateWidth), kClass)
        val weightHh: Tensor<T, V> = executionContext.placeholder(Shape(hiddenSize, gateWidth), kClass)
        val biasIh: Tensor<T, V> = executionContext.placeholder(Shape(gateWidth), kClass)
        val biasHh: Tensor<T, V> = executionContext.placeholder(Shape(gateWidth), kClass)

        return Gru(
            inputSize = inputSize,
            hiddenSize = hiddenSize,
            name = id,
            initWeightIh = weightIh,
            initWeightHh = weightHh,
            initBiasIh = biasIh,
            initBiasHh = biasHh,
            trainable = trainable,
        )
    }
}

private fun <T : DType, V> createLinear(
executionContext: ExecutionContext,
inFeatures: Int,
Expand Down Expand Up @@ -1237,6 +1280,21 @@ public class StageImpl<T : DType, V>(
modules += impl.create()
}

/**
 * Appends a [Gru] layer whose input width is the output width tracked for the preceding
 * layer, then records [hiddenSize] as the new tracked output width/shape.
 */
override fun gru(hiddenSize: Int, id: String, content: GRU<T, V>.() -> Unit) {
    // Capture the incoming feature width before the tracked state is overwritten.
    val precedingFeatures = lastDimension
    lastDimension = hiddenSize
    currentShape = intArrayOf(hiddenSize)

    val layerName = getDefaultName(id, "gru", modules.size)
    val builder = GruImpl<T, V>(
        executionContext,
        inputSize = precedingFeatures,
        hiddenSize = hiddenSize,
        id = layerName,
        kClass = kClass,
    )
    // Let the caller adjust the handle (e.g. `trainable`) before the module is built.
    builder.content()
    modules += builder.create()
}

/** Wraps [activation] in an [ActivationsWrapperModule] and appends it to [modules]. */
override fun activation(id: String, activation: (Tensor<T, V>) -> Tensor<T, V>) {
    val moduleName = getDefaultName(id, "activation", modules.size)
    modules += ActivationsWrapperModule(activation, moduleName)
}
Expand Down Expand Up @@ -1616,6 +1674,21 @@ public class NeuralNetworkDslImpl<T : DType, V>(
modules += impl.create()
}

/**
 * Appends a [Gru] layer whose input width is the output width tracked for the preceding
 * layer, then records [hiddenSize] as the new tracked output width/shape.
 */
override fun gru(hiddenSize: Int, id: String, content: GRU<T, V>.() -> Unit) {
    // Capture the incoming feature width before the tracked state is overwritten.
    val precedingFeatures = lastDimension
    lastDimension = hiddenSize
    currentShape = intArrayOf(hiddenSize)

    val layerName = getDefaultName(id, "gru", modules.size)
    val builder = GruImpl<T, V>(
        executionContext,
        inputSize = precedingFeatures,
        hiddenSize = hiddenSize,
        id = layerName,
        kClass = kClass,
    )
    // Let the caller adjust the handle (e.g. `trainable`) before the module is built.
    builder.content()
    modules += builder.create()
}

/** Wraps [activation] in an [ActivationsWrapperModule] and appends it to [modules]. */
override fun activation(id: String, activation: (Tensor<T, V>) -> Tensor<T, V>) {
    val moduleName = getDefaultName(id, "activation", modules.size)
    modules += ActivationsWrapperModule(activation, moduleName)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package sk.ainet.lang.nn

import sk.ainet.lang.types.FP32
import kotlin.test.Test
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

/** Checks that the `gru(...)` network-DSL builder puts a [Gru] layer into the built model. */
class GruDslTest {

    @Test
    fun gru_dsl_builds_a_gru_layer() {
        val built = definition<FP32, Float> {
            network {
                input(4) // feature width D fed into the GRU
                gru(8) // hidden width H
            }
        }
        assertNotNull(built)

        // Somewhere in the module tree there has to be a Gru wired as D=4 -> H=8.
        val gruLayers = flattenModules(built).filterIsInstance<Gru<FP32, Float>>()
        assertTrue(gruLayers.isNotEmpty(), "network { gru(8) } must produce a Gru module")
        assertTrue(gruLayers.any { layer -> layer.inputSize == 4 && layer.hiddenSize == 8 })
    }

    /** Pre-order listing of [m] followed by all of its descendants. */
    private fun flattenModules(m: Module<FP32, Float>): List<Module<FP32, Float>> {
        val collected = mutableListOf(m)
        for (child in m.modules) {
            collected.addAll(flattenModules(child))
        }
        return collected
    }
}
Loading