Skip to content

Add GRU layer (composed, trainable, StableHLO-exportable)#772

Merged
michalharakal merged 1 commit into
developfrom
feat/gru-module
Jun 28, 2026
Merged

Add GRU layer (composed, trainable, StableHLO-exportable)#772
michalharakal merged 1 commit into
developfrom
feat/gru-module

Conversation

@michalharakal

Copy link
Copy Markdown
Contributor

Closes #217.

What

The first recurrent layer in SKaiNET. GRU is implemented as a Module composing existing primitive ops, unrolled over the static sequence length at trace time.

Why composed rather than a fused op + dedicated converter: StableHLO export has no control flow (no while/scan), so a recurrence must unroll regardless of approach. A composed layer therefore emits the same unrolled primitive graph a fused convertGru would — but reuses every existing, already-validated converter (matmul/add/sigmoid/tanh/multiply/narrow/concat) with zero converter changes, and is trainable for free because every one of those primitives already has a correct autodiff backward.

Scope (first cut)

Single-layer, unidirectional, batch-first: [B, S, D] → [B, S, H]. PyTorch-compatible gate math and order (reset, update, new); weights stored matmul-ready so torch.nn.GRU weights load after a transpose:

r = sigmoid(x·Wir + b_ir + h·Whr + b_hr)
z = sigmoid(x·Wiz + b_iz + h·Whz + b_hz)
n = tanh   (x·Win + b_in + r ⊙ (h·Whn + b_hn))
h' = (1 - z) ⊙ n + z ⊙ h

Bidirectional / num_layers / packed sequences are explicit future work.

Changes

  • nn/Gru.kt — the layer (Module + ModuleParameters; params weight_ih/weight_hh [*,3H], bias_ih/bias_hh [3H]), unrolled forward.
  • nn/dsl/NetworkBuilder.ktGRU dsl item + GruImpl + gru(hiddenSize){} builder (input size inferred from the preceding layer).

Tests

  • GruTest — eager forward vs an independent scalar reference (all gates + recurrence + hidden feedback).
  • ConvPoolBackwardTest.gru_backward_input_matches_finite_diff — input gradient vs central finite difference (proves grads flow through the unrolled cell).
  • GruDslTestnetwork { gru(8) } wires a Gru module.

All three modules' suites pass locally (skainet-lang-core, skainet-backend-cpu, skainet-compile-dag).

Follow-up

End-to-end StableHLO → IREE export validation is added in the conformance suite (gated on a release carrying this layer), same flow as the norm/upsample ops.

🤖 Generated with Claude Code



First recurrent layer in SKaiNET. GRU is implemented as a Module composing
existing primitive ops (matmul/add/sigmoid/tanh/multiply/narrow/concat),
unrolled over the static sequence length at trace time — StableHLO has no loop
construct, so any recurrence must unroll regardless, and a composed layer reuses
every existing, already-validated converter (no new converter needed).

Single-layer, unidirectional, batch-first: [B,S,D] -> [B,S,H]. PyTorch-compatible
gate math/order (reset, update, new), weights stored matmul-ready:
  r = sigmoid(x·Wir + h·Whr + b);  z = sigmoid(x·Wiz + h·Whz + b)
  n = tanh(x·Win + r ⊙ (h·Whn) + b);  h' = (1-z) ⊙ n + z ⊙ h

Trainable for free: every primitive in the unrolled cell already has a correct
autodiff backward (sigmoid out*(1-out), tanh 1-out^2, matmul/add/sub/mul/
narrow/concat).

- nn/Gru.kt: the layer (Module + ModuleParameters, 4 params weight_ih/hh + bias_ih/hh).
- nn/dsl/NetworkBuilder.kt: GRU dsl item + GruImpl + gru(hiddenSize){} builder.
- Tests: eager forward vs independent reference (GruTest), input-grad finite-diff
  (ConvPoolBackwardTest.gru_backward_input_matches_finite_diff), DSL smoke (GruDslTest).

Bidirectional / num_layers / packed sequences are future work. End-to-end IREE
export is validated separately in the conformance suite.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@github-actions

Copy link
Copy Markdown

📖 Documentation Preview

The documentation has been built successfully for this PR.

Generated Files:

  • Operator documentation: docs/modules/operators/_generated_/
  • JSON schema output: operators.json

Artifacts:

  • Download the documentation-preview-772 artifact to view the complete documentation locally.

This comment will be updated automatically when the PR is updated.

@michalharakal michalharakal merged commit 5351684 into develop Jun 28, 2026
11 checks passed
@michalharakal michalharakal deleted the feat/gru-module branch June 28, 2026 19:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

GRU

1 participant