Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@
},
"description": "Array of notes associated with the function"
},
"isDifferentiable": {
"type": "boolean",
"description": "Whether the op carries @Diff, i.e. has a generated backward-rule contract and must be wired into the autodiff dispatch. Sourced from the @Diff annotation."
},
"diffRuleName": {
"type": "string",
"description": "Custom adjoint rule name when @Diff(ruleName=...) is set; omitted for bare @Diff (rule name defaults to the op name)."
},
"validated": {
"type": "boolean",
"description": "Whether the function's documentation has been DARC-validated by a reviewer. Sourced from the @DarcValidated annotation on the function."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,12 @@ internal class RecordingTensorOpsDecorator(private val base: TensorOps) : Tensor
override fun <T : DType, V> indexSelect(input: Tensor<T, V>, indices: Tensor<DType, *>, dim: Int): Tensor<T, V> = base.indexSelect(input, indices, dim)
override fun <T : DType, V> exp(tensor: Tensor<T, V>): Tensor<T, V> = base.exp(tensor)
override fun <T : DType, V> expm1(tensor: Tensor<T, V>): Tensor<T, V> = base.expm1(tensor)
override fun <T : DType, V> sin(tensor: Tensor<T, V>): Tensor<T, V> = base.sin(tensor)
override fun <T : DType, V> cos(tensor: Tensor<T, V>): Tensor<T, V> = base.cos(tensor)
override fun <T : DType, V> convTranspose1d(
input: Tensor<T, V>, weight: Tensor<T, V>, bias: Tensor<T, V>?,
stride: Int, padding: Int, outputPadding: Int, dilation: Int, groups: Int
): Tensor<T, V> = base.convTranspose1d(input, weight, bias, stride, padding, outputPadding, dilation, groups)
override fun <T : DType, V> scaledDotProductAttention(
query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>,
mask: Tensor<T, V>?, scale: Float, causal: Boolean
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ private class TestTensorOps : TensorOps {
override fun <T : DType, V> indexSelect(input: Tensor<T, V>, indices: Tensor<DType, *>, dim: Int): Tensor<T, V> = input
override fun <T : DType, V> exp(tensor: Tensor<T, V>): Tensor<T, V> = tensor
override fun <T : DType, V> expm1(tensor: Tensor<T, V>): Tensor<T, V> = tensor
override fun <T : DType, V> sin(tensor: Tensor<T, V>): Tensor<T, V> = tensor
override fun <T : DType, V> cos(tensor: Tensor<T, V>): Tensor<T, V> = tensor
override fun <T : DType, V> convTranspose1d(input: Tensor<T, V>, weight: Tensor<T, V>, bias: Tensor<T, V>?, stride: Int, padding: Int, outputPadding: Int, dilation: Int, groups: Int): Tensor<T, V> = input
override fun <T : DType, V> scaledDotProductAttention(query: Tensor<T, V>, key: Tensor<T, V>, value: Tensor<T, V>, mask: Tensor<T, V>?, scale: Float, causal: Boolean): Tensor<T, V> = query
override fun <T : DType, V> conv1d(input: Tensor<T, V>, weight: Tensor<T, V>, bias: Tensor<T, V>?, stride: Int, padding: Int, dilation: Int, groups: Int): Tensor<T, V> = input
override fun <T : DType, V> conv2d(input: Tensor<T, V>, weight: Tensor<T, V>, bias: Tensor<T, V>?, stride: Pair<Int, Int>, padding: Pair<Int, Int>, dilation: Pair<Int, Int>, groups: Int): Tensor<T, V> = input
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package sk.ainet.exec.autograd

import sk.ainet.lang.graph.DefaultGradientTape
import sk.ainet.lang.tensor.ops.DifferentiableTensorOpsRules
import kotlin.test.Test
import kotlin.test.assertTrue

/**
* Autodiff-coverage guard: every op marked `@Diff` (the KSP-generated
* [DifferentiableTensorOpsRules.ruleNames]) must have a wired backward dispatch arm in
* [DefaultGradientTape] (its `dispatchedOpNames`).
*
* The `@Diff` → generated `DifferentiableTensorOps` interface already forces a backward *formula*
* to exist (compile error otherwise). This test closes the remaining link: that the formula is
* actually *reachable* from the trace dispatch. It would have caught the historical bug where
* `elu`/`leakyRelu`/`permute` had correct backward formulas that were never wired into the
* dispatch, so their gradients were silently dropped.
*/
class AutodiffCoverageTest {

@Test
fun every_diff_op_has_a_wired_backward_dispatch() {
val dispatched = DefaultGradientTape().dispatchedOpNames
val missing = DifferentiableTensorOpsRules.ruleNames - dispatched
assertTrue(
missing.isEmpty(),
"These @Diff ops have a generated backward contract but no dispatch arm in " +
"DefaultExecutionTape.backwardDispatch (their gradients would silently drop to null): $missing",
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package sk.ainet.exec.autograd

import kotlin.math.abs
import kotlin.test.Test
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import sk.ainet.context.Phase
import sk.ainet.exec.tensor.ops.DefaultCpuOps
import sk.ainet.lang.graph.DefaultComputeGraph
import sk.ainet.lang.graph.DefaultGradientTape
import sk.ainet.lang.graph.DefaultGraphExecutionContext
import sk.ainet.lang.tensor.Shape
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.data.DenseTensorDataFactory
import sk.ainet.lang.tensor.data.FloatArrayTensorData
import sk.ainet.lang.tensor.withRequiresGrad
import sk.ainet.lang.trace.GraphSink
import sk.ainet.lang.types.DType
import sk.ainet.lang.types.FP32
import sk.ainet.lang.types.Int32

/**
* Finite-difference backward parity for the autodiff-gap-closing work: the three previously
* implemented-but-unwired activations (elu, leakyRelu, permute) and the newly-differentiable ops
* (cos, sin, tril, gather, indexSelect, unfold, convTranspose1d). Each analytic gradient is compared
* to central finite difference of a sum-reduced output. Tolerance is generous (FP32 noise);
* correctness, not precision.
*/
class OpsAutodiffBackwardTest {

private fun ctx(): DefaultGraphExecutionContext {
val dataFactory = DenseTensorDataFactory()
val cpuOps = DefaultCpuOps(dataFactory)
val graph = DefaultComputeGraph()
return DefaultGraphExecutionContext(
baseOps = cpuOps,
phase = Phase.TRAIN,
tensorDataFactory = dataFactory,
createTapeFactory = { _ -> DefaultGradientTape(true) },
computeGraph = graph,
baseSink = GraphSink(graph),
)
}

private fun floatTensor(c: DefaultGraphExecutionContext, shape: Shape, values: FloatArray): Tensor<FP32, Float> =
c.fromFloatArray(shape, FP32::class, values)

private fun intTensor(c: DefaultGraphExecutionContext, shape: Shape, values: IntArray): Tensor<DType, Any> {
@Suppress("UNCHECKED_CAST")
return c.fromIntArray<Int32, Int>(shape, Int32::class, values) as Tensor<DType, Any>
}

private fun buf(t: Tensor<*, *>): FloatArray = (t.data as FloatArrayTensorData<*>).buffer

private fun FloatArray.sumElems(): Float {
var s = 0f
for (v in this) s += v
return s
}

private fun assertGradMatchesFiniteDiff(
xShape: Shape,
x0: FloatArray,
eps: Float = 1e-3f,
tol: Float = 3e-2f,
f: (DefaultGraphExecutionContext, Tensor<FP32, Float>) -> Tensor<FP32, Float>,
) {
val ctx = ctx()
val x = floatTensor(ctx, xShape, x0.copyOf()).withRequiresGrad()
val pair = ctx.record {
val out = f(this, x)
out.ops.sum(out)
}
val sumOutput = pair.second
val tape = pair.first as DefaultGradientTape
tape.computeGradients(targets = listOf(sumOutput), sources = listOf(x))
val analyticGrad = x.grad
assertNotNull(analyticGrad, "tape should populate x.grad")
val analytic = buf(analyticGrad)

for (i in x0.indices) {
val xPlus = x0.copyOf().also { it[i] += eps }
val xMinus = x0.copyOf().also { it[i] -= eps }
val ctxPlus = ctx()
val ctxMinus = ctx()
val fPlus = buf(f(ctxPlus, floatTensor(ctxPlus, xShape, xPlus))).sumElems()
val fMinus = buf(f(ctxMinus, floatTensor(ctxMinus, xShape, xMinus))).sumElems()
val fdGrad = (fPlus - fMinus) / (2 * eps)
val diff = abs(analytic[i] - fdGrad)
assertTrue(diff <= tol, "[$i] analytic=${analytic[i]} fd=$fdGrad diff=$diff tol=$tol")
}
}

// ── previously implemented-but-unwired (the silent-grad bug) ──────────────

@Test
fun elu_backward_matches_finite_diff() {
assertGradMatchesFiniteDiff(Shape(6), floatArrayOf(-1.5f, -0.4f, 0.3f, 0.9f, -0.7f, 1.2f), tol = 1e-2f) { _, x ->
x.ops.elu(x, alpha = 1.0f)
}
}

@Test
fun leakyRelu_backward_matches_finite_diff() {
assertGradMatchesFiniteDiff(Shape(6), floatArrayOf(-1.5f, -0.4f, 0.3f, 0.9f, -0.7f, 1.2f), tol = 1e-2f) { _, x ->
x.ops.leakyRelu(x, negativeSlope = 0.1f)
}
}

@Test
fun permute_backward_routes_axes_inverse() {
// sum(w ⊙ permute(x)) — the constant weight makes the upstream non-uniform, so a wrong
// inverse-axes would fail (a plain sum(permute(x)) has all-ones grad and can't detect it).
assertGradMatchesFiniteDiff(Shape(2, 3), FloatArray(6) { (it - 2) * 0.3f }, tol = 1e-2f) { c, x ->
val p = x.ops.permute(x, intArrayOf(1, 0)) // [2,3] -> [3,2]
val w = floatTensor(c, Shape(3, 2), floatArrayOf(1f, 2f, 3f, 4f, 5f, 6f))
x.ops.multiply(p, w)
}
}

// ── trivial new diffs ─────────────────────────────────────────────────────

@Test
fun sin_backward_matches_finite_diff() {
assertGradMatchesFiniteDiff(Shape(5), floatArrayOf(-1f, -0.3f, 0.2f, 0.8f, 1.4f), tol = 1e-2f) { _, x -> x.ops.sin(x) }
}

@Test
fun cos_backward_matches_finite_diff() {
assertGradMatchesFiniteDiff(Shape(5), floatArrayOf(-1f, -0.3f, 0.2f, 0.8f, 1.4f), tol = 1e-2f) { _, x -> x.ops.cos(x) }
}

@Test
fun tril_backward_masks_upper_triangle() {
// grad of sum(tril(x)) is the lower-triangular mask — position-dependent, so a no-op
// backward (passing the full upstream) would fail.
assertGradMatchesFiniteDiff(Shape(3, 3), FloatArray(9) { (it - 4) * 0.2f }, tol = 1e-2f) { _, x -> x.ops.tril(x, 0) }
}

// ── structural new diffs (scatter-add / fold / conv-transpose) ─────────────

@Test
fun gather_backward_scatter_adds_rows() {
// table [vocab=4, emb=3], indices [0,2,2,1] -> row gradients = gather counts (1,1,2,0).
assertGradMatchesFiniteDiff(Shape(4, 3), FloatArray(12) { (it - 6) * 0.1f }) { c, x ->
val idx = intTensor(c, Shape(4), intArrayOf(0, 2, 2, 1))
x.ops.gather(x, idx, dim = 0)
}
}

@Test
fun indexSelect_backward_scatter_adds_along_dim() {
// x [3,4], dim=1, indices [0,2,2] -> col gradients = select counts (1,0,2,0).
assertGradMatchesFiniteDiff(Shape(3, 4), FloatArray(12) { (it - 6) * 0.1f }) { c, x ->
val idx = intTensor(c, Shape(3), intArrayOf(0, 2, 2))
x.ops.indexSelect(x, idx, dim = 1)
}
}

@Test
fun unfold_backward_folds_overlapping_windows() {
// x [6], size 3, step 1 -> 4 windows; each element's grad = number of windows covering it.
assertGradMatchesFiniteDiff(Shape(6), FloatArray(6) { (it - 3) * 0.25f }) { _, x ->
x.ops.unfold(x, dim = 0, size = 3, step = 1)
}
}

@Test
fun convTranspose1d_backward_matches_finite_diff() {
// input [1,1,3], weight [in=1, outPerGroup=1, k=2], stride 1, padding 0.
val w = floatArrayOf(0.5f, -1.2f)
assertGradMatchesFiniteDiff(Shape(1, 1, 3), floatArrayOf(0.3f, -0.2f, 0.7f)) { c, x ->
val wT = floatTensor(c, Shape(1, 1, 2), w)
x.ops.convTranspose1d(x, wT, null, stride = 1, padding = 0, outputPadding = 0, dilation = 1, groups = 1)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public interface TensorOps {
): Tensor<T, V>

// Transposed convolutional operations
@Diff
public fun <T : DType, V> convTranspose1d(
input: Tensor<T, V>,
weight: Tensor<T, V>,
Expand All @@ -117,9 +118,7 @@ public interface TensorOps {
outputPadding: Int = 0,
dilation: Int = 1,
groups: Int = 1
): Tensor<T, V> {
throw NotImplementedError("convTranspose1d not implemented by this TensorOps backend")
}
): Tensor<T, V>

// Pooling operations
@Diff
Expand Down Expand Up @@ -251,20 +250,22 @@ public interface TensorOps {

/** Extract sliding windows of size [size] along dimension [dim] with stride [step].
* Result has one extra dimension appended containing the window elements. */
@Diff
public fun <T : DType, V> unfold(tensor: Tensor<T, V>, dim: Int, size: Int, step: Int): Tensor<T, V>

// Comparison operations (return mask tensors with 1.0 where true, 0.0 where false)

/** Element-wise less than: x < value → 1.0, else 0.0 */
/** Element-wise less than: x < value → 1.0, else 0.0. Non-differentiable by design (boolean mask). */
public fun <T : DType, V> lt(tensor: Tensor<T, V>, value: Float): Tensor<T, V>

/** Element-wise greater than or equal: x >= value → 1.0, else 0.0 */
/** Element-wise greater than or equal: x >= value → 1.0, else 0.0. Non-differentiable by design (boolean mask). */
public fun <T : DType, V> ge(tensor: Tensor<T, V>, value: Float): Tensor<T, V>

// Matrix utilities
@Diff
public fun <T : DType, V> tril(tensor: Tensor<T, V>, k: Int = 0): Tensor<T, V>

// Type conversion operations
// Type conversion operations. Non-differentiable by design (dtype cast).
public fun <TFrom : DType, TTo : DType, V> convert(
tensor: Tensor<TFrom, V>,
targetType: TTo
Expand All @@ -273,15 +274,19 @@ public interface TensorOps {
// --- LLM / Transformer primitives ---

/** Gather rows from [input] along [dim] using integer [indices].
* Primary use: embedding lookup (dim=0, indices=token IDs). */
* Primary use: embedding lookup (dim=0, indices=token IDs).
* Differentiable w.r.t. [input] only (scatter-add); [indices] are discrete. */
@Diff
public fun <T : DType, V> gather(
input: Tensor<T, V>,
indices: Tensor<DType, *>,
dim: Int = 0
): Tensor<T, V>

/** Select elements from [input] along [dim] at the given [indices].
* Similar to gather but for general index selection patterns. */
* Similar to gather but for general index selection patterns.
* Differentiable w.r.t. [input] only (scatter-add); [indices] are discrete. */
@Diff
public fun <T : DType, V> indexSelect(
input: Tensor<T, V>,
indices: Tensor<DType, *>,
Expand All @@ -299,13 +304,11 @@ public interface TensorOps {
public fun <T : DType, V> expm1(tensor: Tensor<T, V>): Tensor<T, V>

// Trigonometric operations
public fun <T : DType, V> sin(tensor: Tensor<T, V>): Tensor<T, V> {
throw NotImplementedError("sin not implemented by this TensorOps backend")
}
@Diff
public fun <T : DType, V> sin(tensor: Tensor<T, V>): Tensor<T, V>

public fun <T : DType, V> cos(tensor: Tensor<T, V>): Tensor<T, V> {
throw NotImplementedError("cos not implemented by this TensorOps backend")
}
@Diff
public fun <T : DType, V> cos(tensor: Tensor<T, V>): Tensor<T, V>

/**
* Scaled dot-product attention.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class TracingWrapperGenerator(
val differentiableMethods = methods.filter { it.isDifferentiable }
if (differentiableMethods.isNotEmpty()) {
generateDifferentiableOps(interfaceDeclaration, differentiableMethods)
generateDifferentiableOpsRules(interfaceDeclaration, differentiableMethods)
}

logger.info("Successfully generated $generatedClassName.kt in package $packageName")
Expand Down Expand Up @@ -117,7 +118,47 @@ class TracingWrapperGenerator(
outputStream.write(code.toByteArray())
}
}


/**
* Generates the authoritative set of differentiable op rule-names (one per @Diff method).
* This is the machine-readable companion to the [generateDifferentiableOps] interface: the
* interface forces every @Diff op to have a backward *formula* (compile error otherwise), while
* this set lets a downstream test assert every @Diff op is also *dispatched* in the execution
* tape — closing the contract ⟷ formula ⟷ wiring loop so none can silently drift.
*/
private fun generateDifferentiableOpsRules(
interfaceDeclaration: KSClassDeclaration,
methods: List<MethodInfo>
) {
val packageName = interfaceDeclaration.packageName.asString()
val objectName = "Differentiable${interfaceDeclaration.simpleName.asString()}Rules"
val ruleNames = methods.map { it.diffRuleName ?: it.name }.distinct().sorted()

val file = codeGenerator.createNewFile(
dependencies = Dependencies(false, interfaceDeclaration.containingFile!!),
packageName = packageName,
fileName = objectName
)

file.use { outputStream ->
val code = buildString {
appendLine("package $packageName")
appendLine()
appendLine("/**")
appendLine(" * Authoritative set of differentiable op rule-names, generated from @Diff annotations on")
appendLine(" * ${interfaceDeclaration.simpleName.asString()}. Used by the autodiff-coverage guard to verify every")
appendLine(" * differentiable op has a wired backward rule. Do not edit by hand.")
appendLine(" */")
appendLine("public object $objectName {")
appendLine(" public val ruleNames: Set<String> = setOf(")
ruleNames.forEach { appendLine(" \"$it\",") }
appendLine(" )")
appendLine("}")
}
outputStream.write(code.toByteArray())
}
}

/**
* Validates that code generation is possible for all methods.
*/
Expand Down Expand Up @@ -294,7 +335,7 @@ class TracingWrapperGenerator(
// Fix the type for nullable parameters that should have ? suffix
val correctedType = when {
param.name == "dim" && method.name in listOf("squeeze", "sum", "mean", "variance") && !param.type.endsWith("?") -> "${param.type}?"
param.name == "bias" && method.name in listOf("conv1d", "conv2d", "conv3d") && !param.type.endsWith("?") -> "${param.type}?"
param.name == "bias" && method.name in listOf("conv1d", "conv2d", "conv3d", "convTranspose1d") && !param.type.endsWith("?") -> "${param.type}?"
param.name == "mask" && method.name == "scaledDotProductAttention" && !param.type.endsWith("?") -> "${param.type}?"
else -> param.type
}
Expand Down
Loading
Loading