diff --git a/kotlin-js-store/yarn.lock b/kotlin-js-store/yarn.lock index eeca6927..4efb8697 100644 --- a/kotlin-js-store/yarn.lock +++ b/kotlin-js-store/yarn.lock @@ -2110,7 +2110,12 @@ wrappy@1: resolved "https://registry.yarnpkg.com/wrappy/-/wrappy-1.0.2.tgz#b5243d8f3ec1aa35f1364605bc0d1036e30ab69f" integrity sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ== -ws@8.18.3, ws@~8.18.3: +ws@8.20.1: + version "8.20.1" + resolved "https://registry.yarnpkg.com/ws/-/ws-8.20.1.tgz#91a9ae2b312ccf98e0a85ec499b48cef45ab0ddb" + integrity sha512-It4dO0K5v//JtTXuPkfEOaI3uUN87iYPnqo/ZzqCoG3g8uhA66QUMs/SrM0YK7/NAu+r4LMh/9dq2A7k+rHs+w== + +ws@~8.18.3: version "8.18.3" resolved "https://registry.yarnpkg.com/ws/-/ws-8.18.3.tgz#b56b88abffde62791c639170400c93dcb0c95472" integrity sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg== diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt index 6aad624e..19f790fd 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt @@ -34,6 +34,7 @@ import sk.ainet.lang.types.FP16 import sk.ainet.lang.types.FP32 import sk.ainet.lang.types.Int32 import sk.ainet.lang.types.Int8 +import kotlin.math.floor import kotlin.math.ln import kotlin.math.log10 as kmLog10 import kotlin.math.log2 as kmLog2 @@ -1257,7 +1258,6 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory require(input.rank == 4) { "upsample2d: input must be 4D (N, C, H, W)" } val (scaleH, scaleW) = scale require(scaleH > 0 && scaleW > 0) { "upsample2d: scale factors must be positive" } - require(mode == UpsampleMode.Nearest) { "upsample2d: only Nearest 
mode is implemented on CPU backend" } val n = input.shape[0] val c = input.shape[1] @@ -1267,16 +1267,61 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory val outW = inW * scaleW val outShape = Shape(n, c, outH, outW) - val outData = dataFactory.init(outShape, input.dtype) { idx -> - val oh = idx[2] - val ow = idx[3] - val ih = oh / scaleH - val iw = ow / scaleW - input.data.get(idx[0], idx[1], ih, iw) + val outData = when (mode) { + UpsampleMode.Nearest -> dataFactory.init(outShape, input.dtype) { idx -> + val oh = idx[2] + val ow = idx[3] + val ih = oh / scaleH + val iw = ow / scaleW + input.data.get(idx[0], idx[1], ih, iw) + } + + UpsampleMode.Bilinear -> { + require(input.dtype == FP32::class || input.dtype == FP16::class) { + "upsample2d: Bilinear mode is only implemented for float dtypes (got ${input.dtype})" + } + dataFactory.init(outShape, input.dtype) { idx -> + val b = idx[0] + val ch = idx[1] + val srcH = sourceCoord(idx[2], scaleH, inH, alignCorners) + val srcW = sourceCoord(idx[3], scaleW, inW, alignCorners) + val ih0 = floor(srcH).toInt().coerceIn(0, inH - 1) + val ih1 = (ih0 + 1).coerceIn(0, inH - 1) + val iw0 = floor(srcW).toInt().coerceIn(0, inW - 1) + val iw1 = (iw0 + 1).coerceIn(0, inW - 1) + val wh = (srcH - ih0).coerceIn(0.0f, 1.0f) + val ww = (srcW - iw0).coerceIn(0.0f, 1.0f) + val v00 = (input.data.get(b, ch, ih0, iw0) as Number).toFloat() + val v01 = (input.data.get(b, ch, ih0, iw1) as Number).toFloat() + val v10 = (input.data.get(b, ch, ih1, iw0) as Number).toFloat() + val v11 = (input.data.get(b, ch, ih1, iw1) as Number).toFloat() + val blend = v00 * (1f - wh) * (1f - ww) + + v01 * (1f - wh) * ww + + v10 * wh * (1f - ww) + + v11 * wh * ww + @Suppress("UNCHECKED_CAST") + (blend as V) + } + } } return newTensor(outData, input.dtype, input) } + /** + * Maps an output coordinate to the (fractional) source coordinate for upsampling, + * matching the PyTorch convention. 
With [alignCorners] = false the sample centers are + * `(o + 0.5) / scale - 0.5`; with align corners the endpoints are pinned via + * `o * (in - 1) / (out - 1)`. The result may fall outside `[0, in-1]`; callers clamp. + */ + private fun sourceCoord(out: Int, scale: Int, inDim: Int, alignCorners: Boolean): Float { + val outDim = inDim * scale + return if (alignCorners) { + if (outDim <= 1) 0f else out.toFloat() * (inDim - 1) / (outDim - 1) + } else { + (out + 0.5f) / scale - 0.5f + } + } + @TensorOp() @InProgress("cpu", owner = "team:cpu", issue = "task-ops.md#op-maxpool2d") override fun maxPool2d( diff --git a/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/DefaultCpuOpsUpsampleTest.kt b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/DefaultCpuOpsUpsampleTest.kt index c045dfe5..c91da2f7 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/DefaultCpuOpsUpsampleTest.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonTest/kotlin/sk/ainet/sk/ainet/exec/tensor/ops/DefaultCpuOpsUpsampleTest.kt @@ -36,4 +36,33 @@ class DefaultCpuOpsUpsampleTest { assertEquals(4f, upsampled.data[0, 0, 3, 3]) } } + + @Test + fun bilinear_mode_blends_neighbors() { + data(ctx) { _ -> + // input rows [1,2] / [3,4]; bilinear 2x2 with PyTorch align_corners=false. + val input = tensor { + shape(1, 1, 2, 2) { + init { idx -> (1 + idx[2] * 2 + idx[3]).toFloat() } + } + } + + val upsampled = ops.upsample2d( + input = input, + scale = 2 to 2, + mode = UpsampleMode.Bilinear, + alignCorners = false + ) + + assertEquals(Shape(1, 1, 4, 4), upsampled.shape) + // Corners clamp to the source corner values. 
+ assertEquals(1f, upsampled.data[0, 0, 0, 0], 1e-5f) + assertEquals(2f, upsampled.data[0, 0, 0, 3], 1e-5f) + assertEquals(3f, upsampled.data[0, 0, 3, 0], 1e-5f) + assertEquals(4f, upsampled.data[0, 0, 3, 3], 1e-5f) + // Interior blends: out[1,1] uses frac 0.25/0.25; out[2,2] uses 0.75/0.75. + assertEquals(1.75f, upsampled.data[0, 0, 1, 1], 1e-5f) + assertEquals(3.25f, upsampled.data[0, 0, 2, 2], 1e-5f) + } + } } diff --git a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt index 5e378baf..72ed17a5 100644 --- a/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt +++ b/skainet-compile/skainet-compile-dag/src/commonMain/kotlin/sk/ainet/lang/graph/DefaultExecutionTape.kt @@ -14,6 +14,7 @@ import sk.ainet.tape.GradientTape import sk.ainet.tape.RecordedOperation import sk.ainet.tape.TapeStack import kotlin.math.exp +import kotlin.math.floor import sk.ainet.lang.tensor.ops.AddOperation import sk.ainet.lang.tensor.ops.DivideOperation import sk.ainet.lang.tensor.ops.MatmulOperation @@ -817,7 +818,8 @@ public class DefaultGradientTape( val input = inputs[0] val scale = pair2(attributes["scale"], 1) val mode = (attributes["mode"] as? String) ?: "Nearest" - return listOf(upsample2dGrad(upstream, input, scale, mode)) + val alignCorners = (attributes["alignCorners"] as? Boolean) ?: false + return listOf(upsample2dGrad(upstream, input, scale, mode, alignCorners)) } override fun leakyReluBackward(upstream: Tensor, output: Tensor, inputs: List>, attributes: Map): List?> { @@ -1761,20 +1763,20 @@ public class DefaultGradientTape( } /** - * upsample2d backward (NEAREST only — the CPU forward only supports - * Nearest, so the backward mirrors that). 
For each input position, sum - * the upstream gradients of every output position it produced (the - * scaleH × scaleW block above-left of [ih*scaleH, iw*scaleW]). + * upsample2d backward — the transpose (scatter) of the forward sampler. + * Nearest: each input position sums the upstream gradients of every output + * position it produced (the scaleH × scaleW block whose top-left corner is + * [ih*scaleH, iw*scaleW]). Bilinear: each output gradient is distributed + * back to the same 4 source neighbors with the same bilinear weights used + * in the forward blend. */ private fun upsample2dGrad( upstream: Tensor, input: Tensor, scale: Pair, mode: String, + alignCorners: Boolean, ): Tensor { - require(mode.equals("Nearest", ignoreCase = true)) { - "upsample2dBackward: only Nearest mode implemented (got mode=$mode)" - } val n = input.shape[0] val c = input.shape[1] val inH = input.shape[2] @@ -1784,25 +1786,69 @@ public class DefaultGradientTape( val outW = upstream.shape[3] val dInput = zerosLike(input) - for (b in 0 until n) { - for (ch in 0 until c) { - for (oh in 0 until outH) { - val ih = oh / scaleH - if (ih !in 0 until inH) continue - for (ow in 0 until outW) { - val iw = ow / scaleW - if (iw !in 0 until inW) continue - val gOut = (upstream.data.get(b, ch, oh, ow) as Number).toFloat() - val cur = (dInput.data.get(b, ch, ih, iw) as Number).toFloat() - @Suppress("UNCHECKED_CAST") - dInput.data.set(b, ch, ih, iw, value = (cur + gOut) as Any) + fun accumulate(b: Int, ch: Int, ih: Int, iw: Int, delta: Float) { + val cur = (dInput.data.get(b, ch, ih, iw) as Number).toFloat() + @Suppress("UNCHECKED_CAST") + dInput.data.set(b, ch, ih, iw, value = (cur + delta) as Any) + } + + when (mode.lowercase()) { + "nearest" -> { + for (b in 0 until n) { + for (ch in 0 until c) { + for (oh in 0 until outH) { + val ih = oh / scaleH + if (ih !in 0 until inH) continue + for (ow in 0 until outW) { + val iw = ow / scaleW + if (iw !in 0 until inW) continue + val gOut = (upstream.data.get(b, ch, oh,
ow) as Number).toFloat() + accumulate(b, ch, ih, iw, gOut) + } + } + } + } + } + + "bilinear" -> { + for (b in 0 until n) { + for (ch in 0 until c) { + for (oh in 0 until outH) { + val srcH = upsampleSourceCoord(oh, scaleH, inH, alignCorners) + val ih0 = floor(srcH).toInt().coerceIn(0, inH - 1) + val ih1 = (ih0 + 1).coerceIn(0, inH - 1) + val wh = (srcH - ih0).coerceIn(0.0f, 1.0f) + for (ow in 0 until outW) { + val srcW = upsampleSourceCoord(ow, scaleW, inW, alignCorners) + val iw0 = floor(srcW).toInt().coerceIn(0, inW - 1) + val iw1 = (iw0 + 1).coerceIn(0, inW - 1) + val ww = (srcW - iw0).coerceIn(0.0f, 1.0f) + val gOut = (upstream.data.get(b, ch, oh, ow) as Number).toFloat() + accumulate(b, ch, ih0, iw0, gOut * (1f - wh) * (1f - ww)) + accumulate(b, ch, ih0, iw1, gOut * (1f - wh) * ww) + accumulate(b, ch, ih1, iw0, gOut * wh * (1f - ww)) + accumulate(b, ch, ih1, iw1, gOut * wh * ww) + } + } } } } + + else -> throw IllegalArgumentException("upsample2dBackward: unsupported mode '$mode'") } return dInput } + /** Output→source coordinate map for upsampling (PyTorch convention); see DefaultCpuOps.sourceCoord. 
*/ + private fun upsampleSourceCoord(out: Int, scale: Int, inDim: Int, alignCorners: Boolean): Float { + val outDim = inDim * scale + return if (alignCorners) { + if (outDim <= 1) 0f else out.toFloat() * (inDim - 1) / (outDim - 1) + } else { + (out + 0.5f) / scale - 0.5f + } + } + private fun clampGrad(upstream: Tensor, input: Tensor, minVal: Float, maxVal: Float): Tensor { val matchedUpstream = matchShape(upstream, input) val gradOut = zerosLike(input) diff --git a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt index f65e73cd..f8d0d00f 100644 --- a/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt +++ b/skainet-compile/skainet-compile-dag/src/commonTest/kotlin/sk/ainet/exec/autograd/ConvPoolBackwardTest.kt @@ -222,4 +222,14 @@ class ConvPoolBackwardTest { x.ops.upsample2d(x, scale = 2 to 2, mode = UpsampleMode.Nearest, alignCorners = false) } } + + @Test + fun upsample2d_bilinear_backward_distributes_weights() { + assertGradMatchesFiniteDiff( + xShape = Shape(1, 1, 3, 3), + x0 = FloatArray(9) { (it - 4) * 0.25f }, + ) { _, x -> + x.ops.upsample2d(x, scale = 2 to 2, mode = UpsampleMode.Bilinear, alignCorners = false) + } + } } diff --git a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt index 551ce27f..d87e987d 100644 --- a/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt +++ b/skainet-compile/skainet-compile-hlo/src/commonMain/kotlin/sk/ainet/compile/hlo/converters/NeuralNetOperationsConverter.kt @@ -4,6 +4,7 @@ import sk.ainet.compile.hlo.ConversionContext import 
sk.ainet.compile.hlo.ConversionResult import sk.ainet.compile.hlo.StableHloOperationConverter import sk.ainet.lang.graph.GraphNode +import kotlin.math.floor /** * Converter for neural network operations. @@ -31,7 +32,9 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { "rmsNorm", "rms_norm", "RMSNorm", "RmsNorm", "groupNorm", "groupNormalization", "GroupNormalization", "group_norm", // Attention - "scaledDotProductAttention" + "scaledDotProductAttention", + // Upsampling / interpolation (Nearest + Bilinear) + "upsample2d", "Upsample2d", "upsample_2d" ) override fun convert( @@ -49,6 +52,7 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { "rmsnorm", "rms_norm" -> convertRmsNorm(node, operands, context) "groupnorm", "groupnormalization", "group_norm" -> convertGroupNorm(node, operands, context) "scaleddotproductattention" -> convertSdpa(node, operands, context) + "upsample2d", "upsample_2d" -> convertUpsample2d(node, operands, context) else -> ConversionResult.Unsupported( node.operation.name, "Operation not supported by NeuralNetOperationsConverter" @@ -696,6 +700,176 @@ public class NeuralNetOperationsConverter : StableHloOperationConverter { ) } + /** + * Lower Upsample2d to traceable StableHLO. Input is NCHW (rank-4); scale, + * mode and alignCorners are all static at trace time, so both modes lower + * to fixed shape/linear ops (no runtime index math, no custom_call): + * + * - Nearest: pixel replication via reshape -> broadcast_in_dim -> reshape, + * exactly matching the eager op out[oh,ow] = in[oh/sH, ow/sW]. + * + * - Bilinear: resize is a separable linear map, so we precompute the two + * resize matrices A_h [outH x inH] and A_w [outW x inW] (each row holds + * the two bilinear neighbor weights) as constants and apply them with two + * dot_general contractions. The weights are the same static floats the + * eager/numpy blend uses, so the result matches to fp tolerance. 
+ */ + private fun convertUpsample2d( + node: GraphNode, + operands: List, + context: ConversionContext + ): ConversionResult { + if (operands.size != 1) { + return ConversionResult.Failure( + "Upsample2d operation requires exactly 1 operand (input), got ${operands.size}", + "Unsupported upsample2d arity for node ${node.id}" + ) + } + + val typeMapper = context.getTypeMapper() + val inputSpec = node.inputs.firstOrNull() + val outputSpec = node.outputs.firstOrNull() + val inputShape = inputSpec?.shape ?: outputSpec?.shape ?: emptyList() + if (inputShape.size != 4) { + return ConversionResult.Failure( + "Upsample2d requires a rank-4 NCHW input, got rank ${inputShape.size}", + "Unsupported upsample2d input rank for node ${node.id}" + ) + } + + val params = node.operation.parameters + val (scaleH, scaleW) = extractScalePair(params) + if (scaleH < 1 || scaleW < 1) { + return ConversionResult.Failure( + "Upsample2d requires positive integer scale, got [$scaleH, $scaleW]", + "Unsupported upsample2d scale for node ${node.id}" + ) + } + val mode = (params["mode"] as? String) ?: "Nearest" + val alignCorners = (params["alignCorners"] as? Boolean) ?: false + + val n = inputShape[0] + val c = inputShape[1] + val h = inputShape[2] + val w = inputShape[3] + val outH = h * scaleH + val outW = w * scaleW + + val elementType = inputSpec?.let { typeMapper.mapDType(it.dtype) } + ?: outputSpec?.let { typeMapper.mapDType(it.dtype) } + ?: "f32" + val inputType = inputSpec?.let { typeMapper.mapTensorType(it) } + ?: "tensor<${n}x${c}x${h}x${w}x$elementType>" + val outputType = outputSpec?.let { typeMapper.mapTensorType(it) } + ?: "tensor<${n}x${c}x${outH}x${outW}x$elementType>" + + val xInput = operands[0] + val operations = mutableListOf() + + return when (mode.lowercase()) { + "nearest" -> { + // Insert unit axes after H and W, replicate each pixel sH x sW, then + // collapse (H,sH)->H*sH and (W,sW)->W*sW. 
+ val expandedType = "tensor<${n}x${c}x${h}x1x${w}x1x$elementType>" + val replicatedType = "tensor<${n}x${c}x${h}x${scaleH}x${w}x${scaleW}x$elementType>" + val expanded = context.nextTempValue() + val replicated = context.nextTempValue() + val result = context.nextTempValue() + operations += "$expanded = stablehlo.reshape $xInput : ($inputType) -> $expandedType" + operations += "$replicated = stablehlo.broadcast_in_dim $expanded, " + + "dims = [0, 1, 2, 3, 4, 5] : ($expandedType) -> $replicatedType" + operations += "$result = stablehlo.reshape $replicated : ($replicatedType) -> $outputType" + operations.forEach { context.emitOperation(it) } + ConversionResult.Success(outputValueName = result, emittedOperations = operations) + } + + "bilinear" -> { + val ah = buildResizeMatrix(h, scaleH, alignCorners) // [outH x inH] + val aw = buildResizeMatrix(w, scaleW, alignCorners) // [outW x inW] + val ahType = "tensor<${outH}x${h}x$elementType>" + val awType = "tensor<${outW}x${w}x$elementType>" + // dot_general output layout = lhs-free ++ rhs-free, so contracting the + // input H axis against A_h yields [N, C, inW, outH]; then contracting that + // inW axis against A_w yields [N, C, outH, outW] — no transposes needed. 
+ val intermediateType = "tensor<${n}x${c}x${w}x${outH}x$elementType>" + + val ahConst = context.nextTempValue() + val awConst = context.nextTempValue() + val tmp = context.nextTempValue() + val result = context.nextTempValue() + + operations += "$ahConst = stablehlo.constant dense<${denseMatrixLiteral(ah)}> : $ahType" + operations += "$awConst = stablehlo.constant dense<${denseMatrixLiteral(aw)}> : $awType" + operations += "$tmp = stablehlo.dot_general $xInput, $ahConst, " + + "contracting_dims = [2] x [1] : ($inputType, $ahType) -> $intermediateType" + operations += "$result = stablehlo.dot_general $tmp, $awConst, " + + "contracting_dims = [2] x [1] : ($intermediateType, $awType) -> $outputType" + operations.forEach { context.emitOperation(it) } + ConversionResult.Success(outputValueName = result, emittedOperations = operations) + } + + else -> ConversionResult.Failure( + "Upsample2d mode '$mode' is not supported (expected Nearest or Bilinear)", + "Unsupported upsample2d mode for node ${node.id}" + ) + } + } + + /** Read the [sH, sW] integer scale from op params (tape records List; also accept Pair/Int). */ + private fun extractScalePair(params: Map): Pair { + return when (val scale = params["scale"]) { + is Pair<*, *> -> + ((scale.first as? Number)?.toInt() ?: 1) to ((scale.second as? Number)?.toInt() ?: 1) + is Number -> scale.toInt() to scale.toInt() + is List<*> -> { + val list = scale.mapNotNull { (it as? Number)?.toInt() } + when { + list.size >= 2 -> list[0] to list[1] + list.size == 1 -> list[0] to list[0] + else -> 1 to 1 + } + } + else -> 1 to 1 + } + } + + /** + * Build the [outDim x inDim] bilinear resize matrix: row o holds the weights of + * the (at most two) source neighbors for output index o, matching the eager + * DefaultCpuOps coordinate map and border clamping. When the two neighbors clamp + * to the same index their weights sum to 1. 
+ */ + private fun buildResizeMatrix(inDim: Int, scale: Int, alignCorners: Boolean): Array { + val outDim = inDim * scale + val m = Array(outDim) { FloatArray(inDim) } + for (o in 0 until outDim) { + val src = if (alignCorners) { + if (outDim <= 1) 0f else o.toFloat() * (inDim - 1) / (outDim - 1) + } else { + (o + 0.5f) / scale - 0.5f + } + val i0 = floor(src).toInt().coerceIn(0, inDim - 1) + val i1 = (i0 + 1).coerceIn(0, inDim - 1) + val frac = (src - i0).coerceIn(0.0f, 1.0f) + m[o][i0] += (1f - frac) + m[o][i1] += frac + } + return m + } + + /** Render a 2D float matrix as a nested-bracket MLIR `dense<...>` literal. */ + private fun denseMatrixLiteral(m: Array): String { + return m.joinToString(prefix = "[", postfix = "]", separator = ", ") { row -> + row.joinToString(prefix = "[", postfix = "]", separator = ", ") { v -> formatMlirFloat(v) } + } + } + + /** Float -> MLIR f-literal. Bilinear weights lie in [0,1] and render as plain decimals. */ + private fun formatMlirFloat(v: Float): String { + val s = v.toString() + return if (s.contains('.') || s.contains('e') || s.contains('E')) s else "$s.0" + } + /** * Lower RMSNorm to real StableHLO elementwise ops. 
This is the * normalization every Llama / Mistral / Qwen / Gemma family diff --git a/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/Upsample2dConverterTest.kt b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/Upsample2dConverterTest.kt new file mode 100644 index 00000000..d1029010 --- /dev/null +++ b/skainet-compile/skainet-compile-hlo/src/commonTest/kotlin/sk/ainet/compile/hlo/Upsample2dConverterTest.kt @@ -0,0 +1,114 @@ +package sk.ainet.compile.hlo + +import sk.ainet.lang.graph.DefaultComputeGraph +import sk.ainet.lang.graph.GraphEdge +import sk.ainet.lang.graph.GraphNode +import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.ops.Operation +import sk.ainet.lang.tensor.ops.TensorSpec +import sk.ainet.lang.tensor.ops.ValidationResult +import sk.ainet.lang.types.DType +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * Pins the new Upsample2d (NCHW) StableHLO lowering. `upsample2d` had no + * converter at all — `NeuralNetOperationsConverter` did not list it, so a + * traced node fell through to the "no converter found" path and could not + * export/compile. Both modes lower to traceable, compilable ops (no + * custom_call), because scale/mode/alignCorners are static at trace time: + * + * Nearest: reshape [N,C,H,W] -> [N,C,H,1,W,1] + * broadcast_in_dim -> [N,C,H,sH,W,sW] (pixel replication) + * reshape -> [N,C,H*sH, W*sW] + * + * Bilinear: resize = separable linear map; two constant resize matrices + * A_h [outH x inH], A_w [outW x inW] applied via dot_general. 
+ */ +class Upsample2dConverterTest { + + @Test + fun upsample2d_nearest_lowers_to_reshape_broadcast_reshape() { + val graph = buildUpsampleGraph(mode = "Nearest", scaleH = 2, scaleW = 2) + val module = StableHloConverterFactory.createExtended().convert(graph, "test_upsample2d_nearest") + val mlir = module.content + println("[DEBUG_LOG] Upsample2d Nearest lowering:\n$mlir") + + assertTrue(mlir.contains("stablehlo.reshape"), "Nearest upsample must reshape to add and merge replication axes") + assertTrue(mlir.contains("stablehlo.broadcast_in_dim"), "Nearest upsample must replicate pixels via broadcast_in_dim") + assertFalse(mlir.contains("custom_call"), "upsample2d must not fall back to a custom_call stub") + assertFalse(mlir.contains("Operation not supported"), "upsample2d must be routed to a converter") + // (1,3,2,2) scale (2,2) -> 6D intermediate (1,3,2,2,2,2) and output (1,3,4,4). + assertTrue(mlir.contains("1x3x2x2x2x2xf32"), "Nearest upsample must build the 6D replication intermediate") + assertTrue(mlir.contains("1x3x4x4xf32"), "Nearest upsample output must be the upsampled NCHW shape") + } + + @Test + fun upsample2d_bilinear_lowers_to_constant_resize_matmuls() { + val graph = buildUpsampleGraph(mode = "Bilinear", scaleH = 2, scaleW = 2) + val module = StableHloConverterFactory.createExtended().convert(graph, "test_upsample2d_bilinear") + val mlir = module.content + println("[DEBUG_LOG] Upsample2d Bilinear lowering:\n$mlir") + + assertFalse(mlir.contains("custom_call"), "Bilinear upsample must lower to real ops, not a custom_call stub") + assertFalse(mlir.contains("Operation not supported"), "Bilinear upsample must be routed to a converter") + assertTrue(mlir.contains("stablehlo.constant dense<"), "Bilinear upsample must emit the constant resize matrices") + assertTrue(mlir.contains("stablehlo.dot_general"), "Bilinear upsample must apply resize matrices via dot_general") + // A_h is [outH x inH] = 4x2, A_w is [outW x inW] = 4x2; output is (1,3,4,4). 
+ assertTrue(mlir.contains("4x2xf32"), "Bilinear upsample must build [out x in] resize matrices") + assertTrue(mlir.contains("1x3x4x4xf32"), "Bilinear upsample output must be the upsampled NCHW shape") + } + + // input (N=1, C=3, H=2, W=2), scale (sH, sW) -> output (1, 3, 2*sH, 2*sW). + private fun buildUpsampleGraph(mode: String, scaleH: Int, scaleW: Int): DefaultComputeGraph { + val graph = DefaultComputeGraph() + val inShape = listOf(1, 3, 2, 2) + val outShape = listOf(1, 3, 2 * scaleH, 2 * scaleW) + + val input = GraphNode( + id = "x", + operation = markerInputOp(), + inputs = emptyList(), + outputs = listOf(TensorSpec("x", inShape, "FP32")) + ) + graph.addNode(input) + + val up = GraphNode( + id = "up1", + operation = upsampleOp(scaleH, scaleW, mode), + inputs = listOf(TensorSpec("x", inShape, "FP32")), + outputs = listOf(TensorSpec("y", outShape, "FP32")) + ) + graph.addNode(up) + graph.addEdge(GraphEdge("e1", input, up, 0, 0, input.outputs[0])) + return graph + } + + private fun markerInputOp(): Operation = object : Operation { + override val name: String = "input" + override val type: String = "input" + override val parameters: Map = emptyMap() + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = emptyList() + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf("name" to name, "type" to type) + } + + private fun upsampleOp(scaleH: Int, scaleW: Int, mode: String): Operation = object : Operation { + override val name: String = "upsample2d" + override val type: String = "nn" + override val parameters: Map = + mapOf("scale" to listOf(scaleH, scaleW), "mode" to mode, "alignCorners" to false) + override fun execute(inputs: List>): List> = + throw UnsupportedOperationException("test fixture only") + override fun 
validateInputs(inputs: List): ValidationResult = ValidationResult.Valid + override fun inferOutputs(inputs: List): List = inputs.take(1) + override fun clone(newParameters: Map): Operation = this + override fun serialize(): Map = mapOf( + "name" to name, "type" to type, "parameters" to parameters + ) + } +}