diff --git a/build-logic/convention/src/main/resources/schemas/operator-doc-schema-v1.json b/build-logic/convention/src/main/resources/schemas/operator-doc-schema-v1.json index 53b50512..ca7f81fe 100644 --- a/build-logic/convention/src/main/resources/schemas/operator-doc-schema-v1.json +++ b/build-logic/convention/src/main/resources/schemas/operator-doc-schema-v1.json @@ -115,6 +115,14 @@ }, "description": "Array of notes associated with the function" }, + "isDifferentiable": { + "type": "boolean", + "description": "Whether the op carries @Diff, i.e. has a generated backward-rule contract and must be wired into the autodiff dispatch. Sourced from the @Diff annotation." + }, + "diffRuleName": { + "type": "string", + "description": "Custom adjoint rule name when @Diff(ruleName=...) is set; omitted for bare @Diff (rule name defaults to the op name)." + }, "validated": { "type": "boolean", "description": "Whether the function's documentation has been DARC-validated by a reviewer. Sourced from the @DarcValidated annotation on the function." diff --git a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt index 1eeb6013..2a32e39e 100644 --- a/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt +++ b/skainet-compile/skainet-compile-core/src/commonMain/kotlin/sk/ainet/tape/RecordingExecution.kt @@ -465,6 +465,12 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor override fun indexSelect(input: Tensor, indices: Tensor, dim: Int): Tensor = base.indexSelect(input, indices, dim) override fun exp(tensor: Tensor): Tensor = base.exp(tensor) override fun expm1(tensor: Tensor): Tensor = base.expm1(tensor) + override fun sin(tensor: Tensor): Tensor = base.sin(tensor) + override fun cos(tensor: Tensor): Tensor = base.cos(tensor) + override fun convTranspose1d( + input: Tensor, weight: Tensor, bias: Tensor?, + stride: Int, padding: Int, outputPadding: Int, dilation: Int, groups: Int + ): Tensor = base.convTranspose1d(input, weight, bias, stride, padding, outputPadding, dilation, groups) override fun scaledDotProductAttention( query: Tensor, key: Tensor, value: Tensor, mask: Tensor?, scale: Float, causal: Boolean diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt index 72ed17a5..3c279d8f 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt @@ -673,9 +673,12 @@ public class DefaultGradientTape( override fun permuteBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { // Gradient of permute(t, axes) is permute(upstream, inverseAxes) - // where inverseAxes[axes[i]] = i. - val axes = (attributes["axes"] as? IntArray) - ?: error("permuteBackward: missing 'axes' attribute") + // where inverseAxes[axes[i]] = i. The trace records axes as a List (axes.toList()). + val axes = when (val a = attributes["axes"]) { + is IntArray -> a + is List<*> -> IntArray(a.size) { (a[it] as Number).toInt() } + else -> error("permuteBackward: missing 'axes' attribute") + } val inverse = IntArray(axes.size) for (i in axes.indices) inverse[axes[i]] = i return listOf(upstream.ops.permute(upstream, inverse)) @@ -1022,70 +1025,278 @@ public class DefaultGradientTape( return inputs.map { null } } + override fun sinBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // d(sin(x))/dx = cos(x) + val x = inputs[0] + return listOf(upstream.ops.multiply(upstream, x.ops.cos(x))) + } + + override fun cosBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // d(cos(x))/dx = -sin(x) + val x = inputs[0] + val negSin = x.ops.mulScalar(x.ops.sin(x), -1.0) + return listOf(upstream.ops.multiply(upstream, negSin)) + } + + override fun trilBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // tril zeroes the strict upper triangle; the gradient keeps the same lower-triangular region. + val k = (attributes["k"] as? Int) ?: 0 + return listOf(upstream.ops.tril(upstream, k)) + } + + override fun gatherBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // gather(input[vocab,emb], indices, dim=0) -> backward scatter-adds upstream rows into the gathered rows. + // Differentiable w.r.t. input only; indices gradient is null. + val input = inputs[0] + val indices = inputs[1] + val gradInput = zerosLike(input) + val numIndices = indices.volume + val indexList = IntArray(numIndices) { (indices.data[it] as Number).toInt() } + fun rowOf(outIdx: IntArray): Int { + val flatIdx = if (outIdx.size == 2) outIdx[0] else { + var flat = 0 + for (d in 0 until outIdx.size - 1) { + flat = flat * (if (d < indices.rank) indices.shape[d] else 1) + outIdx[d] + } + flat + } + return indexList[flatIdx] + } + val upDims = upstream.shape.dimensions + val outIdx = IntArray(upDims.size) + val srcIdx = IntArray(2) + fun walk(d: Int) { + if (d == upDims.size) { + srcIdx[0] = rowOf(outIdx) + srcIdx[1] = outIdx[upDims.size - 1] + accumulateAt(gradInput, srcIdx, upstream, outIdx) + return + } + for (i in 0 until upDims[d]) { outIdx[d] = i; walk(d + 1) } + } + walk(0) + return listOf(gradInput, null) + } + + override fun indexSelectBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // indexSelect(input, indices, dim) -> backward scatter-adds upstream slices into the selected positions. + val input = inputs[0] + val indices = inputs[1] + val dim = (attributes["dim"] as? Int) ?: 0 + val gradInput = zerosLike(input) + val numIndices = indices.volume + val indexList = IntArray(numIndices) { (indices.data[it] as Number).toInt() } + val upDims = upstream.shape.dimensions + val outIdx = IntArray(upDims.size) + val srcIdx = IntArray(upDims.size) + fun walk(d: Int) { + if (d == upDims.size) { + for (k in upDims.indices) srcIdx[k] = if (k == dim) indexList[outIdx[dim]] else outIdx[k] + accumulateAt(gradInput, srcIdx, upstream, outIdx) + return + } + for (i in 0 until upDims[d]) { outIdx[d] = i; walk(d + 1) } + } + walk(0) + return listOf(gradInput, null) + } + + override fun unfoldBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // unfold extracts overlapping windows; backward folds the window gradients back (overlapping-sum). + val input = inputs[0] + val rank = input.shape.rank + val dim = (attributes["dim"] as? Int) ?: 0 + val step = (attributes["step"] as? Int) ?: 1 + val actualDim = if (dim < 0) rank + dim else dim + val gradInput = zerosLike(input) + val upDims = upstream.shape.dimensions // rank + 1 + val outIdx = IntArray(upDims.size) + val srcIdx = IntArray(rank) + fun walk(d: Int) { + if (d == upDims.size) { + val windowIdx = outIdx[rank] + for (i in 0 until rank) srcIdx[i] = if (i == actualDim) outIdx[i] * step + windowIdx else outIdx[i] + accumulateAt(gradInput, srcIdx, upstream, outIdx) + return + } + for (i in 0 until upDims[d]) { outIdx[d] = i; walk(d + 1) } + } + walk(0) + return listOf(gradInput) + } + + override fun convTranspose1dBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { + // Adjoint of convTranspose1d forward: out[b,oc,ol] += in[b,ic,il]*w[ic,oc,k], ol = il*stride - padding + k*dilation. + val input = inputs[0] + val weight = inputs[1] + val hasBias = inputs.size > 2 + val stride = (attributes["stride"] as? Int) ?: 1 + val padding = (attributes["padding"] as? Int) ?: 0 + val dilation = (attributes["dilation"] as? Int) ?: 1 + val groups = (attributes["groups"] as? Int) ?: 1 + + val batch = input.shape[0] + val inChannels = input.shape[1] + val inLength = input.shape[2] + val outChannelsPerGroup = weight.shape[1] + val kernelSize = weight.shape[2] + val outLength = upstream.shape[2] + val inChPerGroup = inChannels / groups + + val gradInput = zerosLike(input) + val gradWeight = zerosLike(weight) + val gIdx = IntArray(3) + val wIdx = IntArray(3) + val uIdx = IntArray(3) + + for (b in 0 until batch) { + for (g in 0 until groups) { + for (ic in 0 until inChPerGroup) { + val inCh = g * inChPerGroup + ic + for (oc in 0 until outChannelsPerGroup) { + val outCh = g * outChannelsPerGroup + oc + for (il in 0 until inLength) { + for (k in 0 until kernelSize) { + val ol = il * stride - padding + k * dilation + if (ol < 0 || ol >= outLength) continue + uIdx[0] = b; uIdx[1] = outCh; uIdx[2] = ol + val up = (upstream.data.get(*uIdx) as Number).toFloat() + wIdx[0] = inCh; wIdx[1] = oc; wIdx[2] = k + val w = (weight.data.get(*wIdx) as Number).toFloat() + gIdx[0] = b; gIdx[1] = inCh; gIdx[2] = il + val xVal = (input.data.get(*gIdx) as Number).toFloat() + accumulateScalar(gradInput, gIdx, up * w) + accumulateScalar(gradWeight, wIdx, xVal * up) + } + } + } + } + } + } + + val grads = mutableListOf?>(gradInput, gradWeight) + if (hasBias) { + val bias = inputs[2] + val gradBias = zerosLike(bias) + val bIdx = IntArray(1) + val outChannels = outChannelsPerGroup * groups + for (oc in 0 until outChannels) { + var acc = 0f + for (b in 0 until batch) { + for (ol in 0 until outLength) { + uIdx[0] = b; uIdx[1] = oc; uIdx[2] = ol + acc += (upstream.data.get(*uIdx) as Number).toFloat() + } + } + bIdx[0] = oc + accumulateScalar(gradBias, bIdx, acc) + } + grads.add(gradBias) + } + return grads + } + + /** Read-add-write [src]'s value at [srcIdx] into [dest] at [destIdx]; used by scatter-add backwards. */ + private fun accumulateAt(dest: Tensor, destIdx: IntArray, src: Tensor, srcIdx: IntArray) { + accumulateScalar(dest, destIdx, (src.data.get(*srcIdx) as Number).toFloat()) + } + + /** Add [delta] into [dest] at [idx]. */ + private fun accumulateScalar(dest: Tensor, idx: IntArray, delta: Float) { + val cur = (dest.data.get(*idx) as Number).toFloat() + dest.data.set(*idx, value = (cur + delta)) + } + + /** + * Trace-op-name → adjoint rule. Backward funcs all share the + * `(upstream, output, inputs, attributes) -> List` signature, so each arm is a member + * reference. This table is the single source of dispatch truth; [dispatchedOpNames] exposes its + * keys to the autodiff-coverage guard (AutodiffCoverageTest) which asserts every `@Diff` op + * (the KSP-generated DifferentiableTensorOpsRules.ruleNames) is present here — preventing the + * silent "implemented-but-not-wired" drift that previously dropped elu/leakyRelu/permute grads. + */ + private val backwardDispatch: + Map, Tensor, List>, Map) -> List?>> = + mapOf( + "add" to ::addBackward, + "addScalar" to ::addScalarBackward, + "subtract" to ::subtractBackward, + "subScalar" to ::subScalarBackward, + "rsubScalar" to ::rsubScalarBackward, + "multiply" to ::multiplyBackward, + "mulScalar" to ::mulScalarBackward, + "divide" to ::divideBackward, + "divScalar" to ::divScalarBackward, + "rdivScalar" to ::rdivScalarBackward, + "matmul" to ::matmulBackward, + "transpose" to ::transposeBackward, + "permute" to ::permuteBackward, + "relu" to ::reluBackward, + "leakyRelu" to ::leakyReluBackward, + "elu" to ::eluBackward, + "sum" to ::sumBackward, + "mean" to ::meanBackward, + "softmax" to ::softmaxBackward, + "logSoftmax" to ::logSoftmaxBackward, + "reshape" to ::reshapeBackward, + "flatten" to ::flattenBackward, + "squeeze" to ::squeezeBackward, + "unsqueeze" to ::unsqueezeBackward, + "sigmoid" to ::sigmoidBackward, + "tanh" to ::tanhBackward, + "silu" to ::siluBackward, + "gelu" to ::geluBackward, + "variance" to ::varianceBackward, + "sqrt" to ::sqrtBackward, + "pow" to ::powBackward, + "powScalar" to ::powScalarBackward, + "log" to ::logBackward, + "log2" to ::log2Backward, + "log10" to ::log10Backward, + "abs" to ::absBackward, + "clamp" to ::clampBackward, + "narrow" to ::narrowBackward, + "pad2d" to ::pad2dBackward, + "conv1d" to ::conv1dBackward, + "conv2d" to ::conv2dBackward, + "conv3d" to ::conv3dBackward, + "convTranspose1d" to ::convTranspose1dBackward, + "maxPool2d" to ::maxPool2dBackward, + "avgPool2d" to ::avgPool2dBackward, + "upsample2d" to ::upsample2dBackward, + "concat" to ::concatBackward, + "split" to ::splitBackward, + "exp" to ::expBackward, + "expm1" to ::expm1Backward, + "sin" to ::sinBackward, + "cos" to ::cosBackward, + "tril" to ::trilBackward, + "gather" to ::gatherBackward, + "indexSelect" to ::indexSelectBackward, + "unfold" to ::unfoldBackward, + "scaledDotProductAttention" to ::scaledDotProductAttentionBackward, + ) + + /** Trace op names with a wired backward rule. Consumed by the autodiff-coverage guard test. */ + internal val dispatchedOpNames: Set get() = backwardDispatch.keys + private fun buildBackwardFromTrace( trace: OpTrace, inputs: List>, output: Tensor ): BackwardOp? { - return when (trace.opType) { - "add" -> BackwardOp(inputs, output) { upstream -> addBackward(upstream, output, inputs, trace.attributes) } - "addScalar" -> BackwardOp(inputs, output) { upstream -> addScalarBackward(upstream, output, inputs, trace.attributes) } - "subtract" -> BackwardOp(inputs, output) { upstream -> subtractBackward(upstream, output, inputs, trace.attributes) } - "subScalar" -> BackwardOp(inputs, output) { upstream -> subScalarBackward(upstream, output, inputs, trace.attributes) } - "rsubScalar" -> BackwardOp(inputs, output) { upstream -> rsubScalarBackward(upstream, output, inputs, trace.attributes) } - "multiply" -> BackwardOp(inputs, output) { upstream -> multiplyBackward(upstream, output, inputs, trace.attributes) } - "mulScalar" -> BackwardOp(inputs, output) { upstream -> mulScalarBackward(upstream, output, inputs, trace.attributes) } - "divide" -> BackwardOp(inputs, output) { upstream -> divideBackward(upstream, output, inputs, trace.attributes) } - "divScalar" -> BackwardOp(inputs, output) { upstream -> divScalarBackward(upstream, output, inputs, trace.attributes) } - "rdivScalar" -> BackwardOp(inputs, output) { upstream -> rdivScalarBackward(upstream, output, inputs, trace.attributes) } - "matmul" -> BackwardOp(inputs, output) { upstream -> matmulBackward(upstream, output, inputs, trace.attributes) } - "transpose" -> BackwardOp(inputs, output) { upstream -> transposeBackward(upstream, output, inputs, trace.attributes) } - "relu" -> BackwardOp(inputs, output) { upstream -> reluBackward(upstream, output, inputs, trace.attributes) } - "sum" -> BackwardOp(inputs, output) { upstream -> sumBackward(upstream, output, inputs, trace.attributes) } - "mean" -> BackwardOp(inputs, output) { upstream -> meanBackward(upstream, output, inputs, trace.attributes) } - "softmax" -> BackwardOp(inputs, output) { upstream -> softmaxBackward(upstream, output, inputs, trace.attributes) } - "logSoftmax" -> BackwardOp(inputs, output) { upstream -> logSoftmaxBackward(upstream, output, inputs, trace.attributes) } - "reshape" -> BackwardOp(inputs, output) { upstream -> reshapeBackward(upstream, output, inputs, trace.attributes) } - "flatten" -> BackwardOp(inputs, output) { upstream -> flattenBackward(upstream, output, inputs, trace.attributes) } - "squeeze" -> BackwardOp(inputs, output) { upstream -> squeezeBackward(upstream, output, inputs, trace.attributes) } - "unsqueeze" -> BackwardOp(inputs, output) { upstream -> unsqueezeBackward(upstream, output, inputs, trace.attributes) } - "sigmoid" -> BackwardOp(inputs, output) { upstream -> sigmoidBackward(upstream, output, inputs, trace.attributes) } - "tanh" -> BackwardOp(inputs, output) { upstream -> tanhBackward(upstream, output, inputs, trace.attributes) } - "silu" -> BackwardOp(inputs, output) { upstream -> siluBackward(upstream, output, inputs, trace.attributes) } - "gelu" -> BackwardOp(inputs, output) { upstream -> geluBackward(upstream, output, inputs, trace.attributes) } - "variance" -> BackwardOp(inputs, output) { upstream -> varianceBackward(upstream, output, inputs, trace.attributes) } - "sqrt" -> BackwardOp(inputs, output) { upstream -> sqrtBackward(upstream, output, inputs, trace.attributes) } - "pow" -> BackwardOp(inputs, output) { upstream -> powBackward(upstream, output, inputs, trace.attributes) } - "powScalar" -> BackwardOp(inputs, output) { upstream -> powScalarBackward(upstream, output, inputs, trace.attributes) } - "log" -> BackwardOp(inputs, output) { upstream -> logBackward(upstream, output, inputs, trace.attributes) } - "log2" -> BackwardOp(inputs, output) { upstream -> log2Backward(upstream, output, inputs, trace.attributes) } - "log10" -> BackwardOp(inputs, output) { upstream -> log10Backward(upstream, output, inputs, trace.attributes) } - "abs" -> BackwardOp(inputs, output) { upstream -> absBackward(upstream, output, inputs, trace.attributes) } - "clamp" -> BackwardOp(inputs, output) { upstream -> clampBackward(upstream, output, inputs, trace.attributes) } - "narrow" -> BackwardOp(inputs, output) { upstream -> narrowBackward(upstream, output, inputs, trace.attributes) } - "pad2d" -> BackwardOp(inputs, output) { upstream -> pad2dBackward(upstream, output, inputs, trace.attributes) } - "conv1d" -> BackwardOp(inputs, output) { upstream -> conv1dBackward(upstream, output, inputs, trace.attributes) } - "conv2d" -> BackwardOp(inputs, output) { upstream -> conv2dBackward(upstream, output, inputs, trace.attributes) } - "conv3d" -> BackwardOp(inputs, output) { upstream -> conv3dBackward(upstream, output, inputs, trace.attributes) } - "maxPool2d" -> BackwardOp(inputs, output) { upstream -> maxPool2dBackward(upstream, output, inputs, trace.attributes) } - "avgPool2d" -> BackwardOp(inputs, output) { upstream -> avgPool2dBackward(upstream, output, inputs, trace.attributes) } - "upsample2d" -> BackwardOp(inputs, output) { upstream -> upsample2dBackward(upstream, output, inputs, trace.attributes) } - "concat" -> BackwardOp(inputs, output) { upstream -> concatBackward(upstream, output, inputs, trace.attributes) } - "split" -> BackwardOp(inputs, output) { upstream -> splitBackward(upstream, output, inputs, trace.attributes) } - "exp" -> BackwardOp(inputs, output) { upstream -> expBackward(upstream, output, inputs, trace.attributes) } - "expm1" -> BackwardOp(inputs, output) { upstream -> expm1Backward(upstream, output, inputs, trace.attributes) } - "scaledDotProductAttention" -> BackwardOp(inputs, output) { upstream -> scaledDotProductAttentionBackward(upstream, output, inputs, trace.attributes) } - else -> { - // Support custom backward functions passed via trace attributes - @Suppress("UNCHECKED_CAST") - val customBackward = trace.attributes["_backwardFn"] - as? (Tensor, Tensor, List>, Map) -> List?> - if (customBackward != null) { - BackwardOp(inputs, output) { upstream -> customBackward(upstream, output, inputs, trace.attributes) } - } else { - null - } - } + val rule = backwardDispatch[trace.opType] + if (rule != null) { + return BackwardOp(inputs, output) { upstream -> rule(upstream, output, inputs, trace.attributes) } + } + // Support custom backward functions passed via trace attributes + @Suppress("UNCHECKED_CAST") + val customBackward = trace.attributes["_backwardFn"] + as? (Tensor, Tensor, List>, Map) -> List?> + return if (customBackward != null) { + BackwardOp(inputs, output) { upstream -> customBackward(upstream, output, inputs, trace.attributes) } + } else { + null } } diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt index 3fb1c852..b5ea2d13 100644 --- a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/compile/graph/ComputeGraphExecutorTest.kt @@ -203,6 +203,9 @@ private class TestTensorOps : TensorOps { override fun indexSelect(input: Tensor, indices: Tensor, dim: Int): Tensor = input override fun exp(tensor: Tensor): Tensor = tensor override fun expm1(tensor: Tensor): Tensor = tensor + override fun sin(tensor: Tensor): Tensor = tensor + override fun cos(tensor: Tensor): Tensor = tensor + override fun convTranspose1d(input: Tensor, weight: Tensor, bias: Tensor?, stride: Int, padding: Int, outputPadding: Int, dilation: Int, groups: Int): Tensor = input override fun scaledDotProductAttention(query: Tensor, key: Tensor, value: Tensor, mask: Tensor?, scale: Float, causal: Boolean): Tensor = query override fun conv1d(input: Tensor, weight: Tensor, bias: Tensor?, stride: Int, padding: Int, dilation: Int, groups: Int): Tensor = input override fun conv2d(input: Tensor, weight: Tensor, bias: Tensor?, stride: Pair, padding: Pair, dilation: Pair, groups: Int): Tensor = input diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/AutodiffCoverageTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/AutodiffCoverageTest.kt new file mode 100644 index 00000000..93b93733 --- /dev/null +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/AutodiffCoverageTest.kt @@ -0,0 +1,31 @@ +package sk.ainet.exec.autograd + +import sk.ainet.lang.graph.DefaultGradientTape +import sk.ainet.lang.tensor.ops.DifferentiableTensorOpsRules +import kotlin.test.Test +import kotlin.test.assertTrue + +/** + * Autodiff-coverage guard: every op marked `@Diff` (the KSP-generated + * [DifferentiableTensorOpsRules.ruleNames]) must have a wired backward dispatch arm in + * [DefaultGradientTape] (its `dispatchedOpNames`). + * + * The `@Diff` → generated `DifferentiableTensorOps` interface already forces a backward *formula* + * to exist (compile error otherwise). This test closes the remaining link: that the formula is + * actually *reachable* from the trace dispatch. It would have caught the historical bug where + * `elu`/`leakyRelu`/`permute` had correct backward formulas that were never wired into the + * dispatch, so their gradients were silently dropped. + */ +class AutodiffCoverageTest { + + @Test + fun every_diff_op_has_a_wired_backward_dispatch() { + val dispatched = DefaultGradientTape().dispatchedOpNames + val missing = DifferentiableTensorOpsRules.ruleNames - dispatched + assertTrue( + missing.isEmpty(), + "These @Diff ops have a generated backward contract but no dispatch arm in " + + "DefaultExecutionTape.backwardDispatch (their gradients would silently drop to null): $missing", + ) + } +} diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/OpsAutodiffBackwardTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/OpsAutodiffBackwardTest.kt new file mode 100644 index 00000000..57832380 --- /dev/null +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/OpsAutodiffBackwardTest.kt @@ -0,0 +1,177 @@ +package sk.ainet.exec.autograd + +import kotlin.math.abs +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import sk.ainet.context.Phase +import sk.ainet.exec.tensor.ops.DefaultCpuOps +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.DefaultGradientTape +import sk.ainet.lang.graph.DefaultGraphExecutionContext +import sk.ainet.lang.tensor.Shape +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.data.DenseTensorDataFactory +import sk.ainet.lang.tensor.data.FloatArrayTensorData +import sk.ainet.lang.tensor.withRequiresGrad +import sk.ainet.lang.trace.GraphSink +import sk.ainet.lang.types.DType +import sk.ainet.lang.types.FP32 +import sk.ainet.lang.types.Int32 + +/** + * Finite-difference backward parity for the autodiff-gap-closing work: the three previously + * implemented-but-unwired activations (elu, leakyRelu, permute) and the newly-differentiable ops + * (cos, sin, tril, gather, indexSelect, unfold, convTranspose1d). Each analytic gradient is compared + * to central finite difference of a sum-reduced output. Tolerance is generous (FP32 noise); + * correctness, not precision. + */ +class OpsAutodiffBackwardTest { + + private fun ctx(): DefaultGraphExecutionContext { + val dataFactory = DenseTensorDataFactory() + val cpuOps = DefaultCpuOps(dataFactory) + val graph = DefaultComputeGraph() + return DefaultGraphExecutionContext( + baseOps = cpuOps, + phase = Phase.TRAIN, + tensorDataFactory = dataFactory, + createTapeFactory = { _ -> DefaultGradientTape(true) }, + computeGraph = graph, + baseSink = GraphSink(graph), + ) + } + + private fun floatTensor(c: DefaultGraphExecutionContext, shape: Shape, values: FloatArray): Tensor = + c.fromFloatArray(shape, FP32::class, values) + + private fun intTensor(c: DefaultGraphExecutionContext, shape: Shape, values: IntArray): Tensor { + @Suppress("UNCHECKED_CAST") + return c.fromIntArray(shape, Int32::class, values) as Tensor + } + + private fun buf(t: Tensor<*, *>): FloatArray = (t.data as FloatArrayTensorData<*>).buffer + + private fun FloatArray.sumElems(): Float { + var s = 0f + for (v in this) s += v + return s + } + + private fun assertGradMatchesFiniteDiff( + xShape: Shape, + x0: FloatArray, + eps: Float = 1e-3f, + tol: Float = 3e-2f, + f: (DefaultGraphExecutionContext, Tensor) -> Tensor, + ) { + val ctx = ctx() + val x = floatTensor(ctx, xShape, x0.copyOf()).withRequiresGrad() + val pair = ctx.record { + val out = f(this, x) + out.ops.sum(out) + } + val sumOutput = pair.second + val tape = pair.first as DefaultGradientTape + tape.computeGradients(targets = listOf(sumOutput), sources = listOf(x)) + val analyticGrad = x.grad + assertNotNull(analyticGrad, "tape should populate x.grad") + val analytic = buf(analyticGrad) + + for (i in x0.indices) { + val xPlus = x0.copyOf().also { it[i] += eps } + val xMinus = x0.copyOf().also { it[i] -= eps } + val ctxPlus = ctx() + val ctxMinus = ctx() + val fPlus = buf(f(ctxPlus, floatTensor(ctxPlus, xShape, xPlus))).sumElems() + val fMinus = buf(f(ctxMinus, floatTensor(ctxMinus, xShape, xMinus))).sumElems() + val fdGrad = (fPlus - fMinus) / (2 * eps) + val diff = abs(analytic[i] - fdGrad) + assertTrue(diff <= tol, "[$i] analytic=${analytic[i]} fd=$fdGrad diff=$diff tol=$tol") + } + } + + // ── previously implemented-but-unwired (the silent-grad bug) ────────────── + + @Test + fun elu_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(Shape(6), floatArrayOf(-1.5f, -0.4f, 0.3f, 0.9f, -0.7f, 1.2f), tol = 1e-2f) { _, x -> + x.ops.elu(x, alpha = 1.0f) + } + } + + @Test + fun leakyRelu_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(Shape(6), floatArrayOf(-1.5f, -0.4f, 0.3f, 0.9f, -0.7f, 1.2f), tol = 1e-2f) { _, x -> + x.ops.leakyRelu(x, negativeSlope = 0.1f) + } + } + + @Test + fun permute_backward_routes_axes_inverse() { + // sum(w ⊙ permute(x)) — the constant weight makes the upstream non-uniform, so a wrong + // inverse-axes would fail (a plain sum(permute(x)) has all-ones grad and can't detect it). + assertGradMatchesFiniteDiff(Shape(2, 3), FloatArray(6) { (it - 2) * 0.3f }, tol = 1e-2f) { c, x -> + val p = x.ops.permute(x, intArrayOf(1, 0)) // [2,3] -> [3,2] + val w = floatTensor(c, Shape(3, 2), floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f)) + x.ops.multiply(p, w) + } + } + + // ── trivial new diffs ───────────────────────────────────────────────────── + + @Test + fun sin_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(Shape(5), floatArrayOf(-1f, -0.3f, 0.2f, 0.8f, 1.4f), tol = 1e-2f) { _, x -> x.ops.sin(x) } + } + + @Test + fun cos_backward_matches_finite_diff() { + assertGradMatchesFiniteDiff(Shape(5), floatArrayOf(-1f, -0.3f, 0.2f, 0.8f, 1.4f), tol = 1e-2f) { _, x -> x.ops.cos(x) } + } + + @Test + fun tril_backward_masks_upper_triangle() { + // grad of sum(tril(x)) is the lower-triangular mask — position-dependent, so a no-op + // backward (passing the full upstream) would fail. + assertGradMatchesFiniteDiff(Shape(3, 3), FloatArray(9) { (it - 4) * 0.2f }, tol = 1e-2f) { _, x -> x.ops.tril(x, 0) } + } + + // ── structural new diffs (scatter-add / fold / conv-transpose) ───────────── + + @Test + fun gather_backward_scatter_adds_rows() { + // table [vocab=4, emb=3], indices [0,2,2,1] -> row gradients = gather counts (1,1,2,0). + assertGradMatchesFiniteDiff(Shape(4, 3), FloatArray(12) { (it - 6) * 0.1f }) { c, x -> + val idx = intTensor(c, Shape(4), intArrayOf(0, 2, 2, 1)) + x.ops.gather(x, idx, dim = 0) + } + } + + @Test + fun indexSelect_backward_scatter_adds_along_dim() { + // x [3,4], dim=1, indices [0,2,2] -> col gradients = select counts (1,0,2,0). + assertGradMatchesFiniteDiff(Shape(3, 4), FloatArray(12) { (it - 6) * 0.1f }) { c, x -> + val idx = intTensor(c, Shape(3), intArrayOf(0, 2, 2)) + x.ops.indexSelect(x, idx, dim = 1) + } + } + + @Test + fun unfold_backward_folds_overlapping_windows() { + // x [6], size 3, step 1 -> 4 windows; each element's grad = number of windows covering it. + assertGradMatchesFiniteDiff(Shape(6), FloatArray(6) { (it - 3) * 0.25f }) { _, x -> + x.ops.unfold(x, dim = 0, size = 3, step = 1) + } + } + + @Test + fun convTranspose1d_backward_matches_finite_diff() { + // input [1,1,3], weight [in=1, outPerGroup=1, k=2], stride 1, padding 0. + val w = floatArrayOf(0.5f, -1.2f) + assertGradMatchesFiniteDiff(Shape(1, 1, 3), floatArrayOf(0.3f, -0.2f, 0.7f)) { c, x -> + val wT = floatTensor(c, Shape(1, 1, 2), w) + x.ops.convTranspose1d(x, wT, null, stride = 1, padding = 0, outputPadding = 0, dilation = 1, groups = 1) + } + } +} diff --git a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt index e55afeaa..d326cd2a 100644 --- a/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt +++ b/skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/tensor/ops/TensorOps.kt @@ -108,6 +108,7 @@ public interface TensorOps { ): Tensor // Transposed convolutional operations + @Diff public fun convTranspose1d( input: Tensor, weight: Tensor, @@ -117,9 +118,7 @@ public interface TensorOps { outputPadding: Int = 0, dilation: Int = 1, groups: Int = 1 - ): Tensor { - throw NotImplementedError("convTranspose1d not implemented by this TensorOps backend") - } + ): Tensor // Pooling operations @Diff @@ -251,20 +250,22 @@ public interface TensorOps { /** Extract sliding windows of size [size] along dimension [dim] with stride [step]. * Result has one extra dimension appended containing the window elements. */ + @Diff public fun unfold(tensor: Tensor, dim: Int, size: Int, step: Int): Tensor // Comparison operations (return mask tensors with 1.0 where true, 0.0 where false) - /** Element-wise less than: x < value → 1.0, else 0.0 */ + /** Element-wise less than: x < value → 1.0, else 0.0. Non-differentiable by design (boolean mask). */ public fun lt(tensor: Tensor, value: Float): Tensor - /** Element-wise greater than or equal: x >= value → 1.0, else 0.0 */ + /** Element-wise greater than or equal: x >= value → 1.0, else 0.0. Non-differentiable by design (boolean mask). */ public fun ge(tensor: Tensor, value: Float): Tensor // Matrix utilities + @Diff public fun tril(tensor: Tensor, k: Int = 0): Tensor - // Type conversion operations + // Type conversion operations. Non-differentiable by design (dtype cast). public fun convert( tensor: Tensor, targetType: TTo @@ -273,7 +274,9 @@ public interface TensorOps { // --- LLM / Transformer primitives --- /** Gather rows from [input] along [dim] using integer [indices]. - * Primary use: embedding lookup (dim=0, indices=token IDs). */ + * Primary use: embedding lookup (dim=0, indices=token IDs). + * Differentiable w.r.t. [input] only (scatter-add); [indices] are discrete. */ + @Diff public fun gather( input: Tensor, indices: Tensor, @@ -281,7 +284,9 @@ public interface TensorOps { ): Tensor /** Select elements from [input] along [dim] at the given [indices]. - * Similar to gather but for general index selection patterns. */ + * Similar to gather but for general index selection patterns. + * Differentiable w.r.t. [input] only (scatter-add); [indices] are discrete. */ + @Diff public fun indexSelect( input: Tensor, indices: Tensor, @@ -299,13 +304,11 @@ public interface TensorOps { public fun expm1(tensor: Tensor): Tensor // Trigonometric operations - public fun sin(tensor: Tensor): Tensor { - throw NotImplementedError("sin not implemented by this TensorOps backend") - } + @Diff + public fun sin(tensor: Tensor): Tensor - public fun cos(tensor: Tensor): Tensor { - throw NotImplementedError("cos not implemented by this TensorOps backend") - } + @Diff + public fun cos(tensor: Tensor): Tensor /** * Scaled dot-product attention. diff --git a/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ksp/TracingWrapperProcessor.kt b/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ksp/TracingWrapperProcessor.kt index ce2432f4..0f4306f1 100644 --- a/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ksp/TracingWrapperProcessor.kt +++ b/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ksp/TracingWrapperProcessor.kt @@ -57,6 +57,7 @@ class TracingWrapperGenerator( val differentiableMethods = methods.filter { it.isDifferentiable } if (differentiableMethods.isNotEmpty()) { generateDifferentiableOps(interfaceDeclaration, differentiableMethods) + generateDifferentiableOpsRules(interfaceDeclaration, differentiableMethods) } logger.info("Successfully generated $generatedClassName.kt in package $packageName") @@ -117,7 +118,47 @@ class TracingWrapperGenerator( outputStream.write(code.toByteArray()) } } - + + /** + * Generates the authoritative set of differentiable op rule-names (one per @Diff method). + * This is the machine-readable companion to the [generateDifferentiableOps] interface: the + * interface forces every @Diff op to have a backward *formula* (compile error otherwise), while + * this set lets a downstream test assert every @Diff op is also *dispatched* in the execution + * tape — closing the contract ⟷ formula ⟷ wiring loop so none can silently drift. + */ + private fun generateDifferentiableOpsRules( + interfaceDeclaration: KSClassDeclaration, + methods: List + ) { + val packageName = interfaceDeclaration.packageName.asString() + val objectName = "Differentiable${interfaceDeclaration.simpleName.asString()}Rules" + val ruleNames = methods.map { it.diffRuleName ?: it.name }.distinct().sorted() + + val file = codeGenerator.createNewFile( + dependencies = Dependencies(false, interfaceDeclaration.containingFile!!), + packageName = packageName, + fileName = objectName + ) + + file.use { outputStream -> + val code = buildString { + appendLine("package $packageName") + appendLine() + appendLine("/**") + appendLine(" * Authoritative set of differentiable op rule-names, generated from @Diff annotations on") + appendLine(" * ${interfaceDeclaration.simpleName.asString()}. Used by the autodiff-coverage guard to verify every") + appendLine(" * differentiable op has a wired backward rule. Do not edit by hand.") + appendLine(" */") + appendLine("public object $objectName {") + appendLine(" public val ruleNames: Set = setOf(") + ruleNames.forEach { appendLine(" \"$it\",") } + appendLine(" )") + appendLine("}") + } + outputStream.write(code.toByteArray()) + } + } + /** * Validates that code generation is possible for all methods. */ @@ -294,7 +335,7 @@ class TracingWrapperGenerator( // Fix the type for nullable parameters that should have ? suffix val correctedType = when { param.name == "dim" && method.name in listOf("squeeze", "sum", "mean", "variance") && !param.type.endsWith("?") -> "${param.type}?" - param.name == "bias" && method.name in listOf("conv1d", "conv2d", "conv3d") && !param.type.endsWith("?") -> "${param.type}?" + param.name == "bias" && method.name in listOf("conv1d", "conv2d", "conv3d", "convTranspose1d") && !param.type.endsWith("?") -> "${param.type}?" param.name == "mask" && method.name == "scaledDotProductAttention" && !param.type.endsWith("?") -> "${param.type}?" else -> param.type } diff --git a/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ops/ksp/OperatorDocProcessor.kt b/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ops/ksp/OperatorDocProcessor.kt index 6472c956..33206760 100644 --- a/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ops/ksp/OperatorDocProcessor.kt +++ b/skainet-lang/skainet-lang-ksp-processor/src/main/kotlin/sk/ainet/lang/ops/ksp/OperatorDocProcessor.kt @@ -31,6 +31,11 @@ data class FunctionDoc( val returnType: String, val statusByBackend: Map, val notes: List, + // Autodiff coverage: whether the op carries @Diff (so it has a generated backward + // contract and must be wired into the execution tape's dispatch). diffRuleName is the + // custom adjoint rule name if @Diff(ruleName=...) was set, else empty (= method name). + val isDifferentiable: Boolean = false, + val diffRuleName: String = "", // DARC validation metadata. `validated = false` means @DarcValidated is // absent — the generator will render a "not validated" badge. val validated: Boolean = false, @@ -195,6 +200,8 @@ class OperatorDocProcessor( returnType = extractReturnType(fn), statusByBackend = statusByBackend, notes = emptyList(), + isDifferentiable = extractDiffRule(fn) != null, + diffRuleName = extractDiffRule(fn).orEmpty(), validated = validation.validated, validatedBy = validation.by, validatedOn = validation.on, @@ -281,6 +288,8 @@ class OperatorDocProcessor( returnType = extractReturnType(function), statusByBackend = deriveStatusByBackend(function), notes = deriveNotes(function), + isDifferentiable = extractDiffRule(function) != null, + diffRuleName = extractDiffRule(function).orEmpty(), validated = validation.validated, validatedBy = validation.by, validatedOn = validation.on, @@ -289,6 +298,20 @@ class OperatorDocProcessor( ) } + /** + * Returns the adjoint rule name when the function carries `@Diff` (its `ruleName` argument, or + * the empty string for bare `@Diff`), or `null` when the op is not differentiable. Mirrors the + * detection in the autodiff KSP processor (`MethodAnalyzer`) so operators.json agrees with the + * generated `DifferentiableTensorOps` contract. + */ + private fun extractDiffRule(function: KSFunctionDeclaration): String? { + val diff = function.annotations.find { + it.shortName.asString() == "Diff" || + it.annotationType.resolve().declaration.qualifiedName?.asString() == "sk.ainet.lang.trace.Diff" + } ?: return null + return diff.arguments.find { it.name?.asString() == "ruleName" }?.value as? String ?: "" + } + private data class DarcValidation( val validated: Boolean, val by: String, @@ -567,6 +590,13 @@ class OperatorDocProcessor( } append("]") + // Autodiff coverage. Always emitted so the manifest is the single source of + // differentiability truth; diffRuleName only when a custom @Diff(ruleName=...) is set. + append(",\n \"isDifferentiable\": ${function.isDifferentiable}") + if (function.diffRuleName.isNotEmpty()) { + append(",\n \"diffRuleName\": \"${escapeJson(function.diffRuleName)}\"") + } + // DARC validation block. Only emitted when an actual // @DarcValidated annotation is present, so unannotated // functions keep the JSON narrow.