diff --git a/README.md b/README.md index 16d0d8de..d11dc9bc 100644 --- a/README.md +++ b/README.md @@ -224,8 +224,35 @@ Runnable examples: - Built-in loaders: MNIST, Fashion-MNIST, CIFAR-10 - URI-backed data sources: `file://`, `https://`, `hf+https://`, and `hf://...` +- Dataset operations: deterministic shuffle/split, stratified split, filter/map/transform views, batch flows, and epoch flows +- Raw dataset parsers: CSV, TSV, JSON arrays/objects, JSON Lines (`.jsonl`, `.ndjson`) +- Type-safe transform DSLs: image/tensor transforms plus suspendable raw data pipelines - Formats: GGUF, ONNX, SafeTensors, JSON, Image (JPEG, PNG) -- Type-safe transform DSL: resize, crop, normalize, toTensor + +```kotlin +val raw = JvmDataSourceResolver().rawDataset { + from("hf://datasets/org/repo@main/train.jsonl") + format(DataFormat.JSON_LINES) + cachePolicy(CachePolicy.Use) +} + +val withoutLabel = dataPipeline() + .stage( + dataTransformer( + name = "drop-label", + outputSchema = { schema -> DataSchema(schema.columns - "label") } + ) { dataset -> + val columns = dataset.schema.columns - "label" + dataset.copy( + schema = DataSchema(columns), + rows = dataset.rows.map { row -> + RawDataRow(row.values.filterKeys { key -> key in columns }) + } + ) + } + ) + .execute(raw) +``` ### Edge AI: Arduino / C99 Export diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt index 19f790fd..8aca6dc2 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt @@ -390,37 +390,39 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory require(a.dtype == b.dtype) { "DType mismatch: ${a.dtype} vs ${b.dtype}" } // Packed-quant fast path (FP32 input × packed 
weight), resolved via KernelRegistry. - chooseQuantizedMatmulHeap(a, b)?.let { return it } + KernelProfile.timeQuant { chooseQuantizedMatmulHeap(a, b) }?.let { return it } // Fast path: 2D × 2D with FloatArray backing — direct buffer access, no per-element allocation if (a.rank == 2 && b.rank == 2 && (a.dtype == FP32::class) && a.data is FloatArrayTensorData<*> && b.data is FloatArrayTensorData<*> ) { - val aBuf = (a.data as FloatArrayTensorData<*>).buffer - val bBuf = (b.data as FloatArrayTensorData<*>).buffer - val m = a.shape[0] - val k = a.shape[1] - val n = b.shape[1] - require(k == b.shape[0]) { "Matrix multiplication shape mismatch: ${a.shape} vs ${b.shape}" } - val out = FloatArray(m * n) - for (i in 0 until m) { - val aOff = i * k - for (j in 0 until n) { - var sum = 0f - for (p in 0 until k) { - sum += aBuf[aOff + p] * bBuf[p * n + j] + return KernelProfile.timeFp32 { + val aBuf = (a.data as FloatArrayTensorData<*>).buffer + val bBuf = (b.data as FloatArrayTensorData<*>).buffer + val m = a.shape[0] + val k = a.shape[1] + val n = b.shape[1] + require(k == b.shape[0]) { "Matrix multiplication shape mismatch: ${a.shape} vs ${b.shape}" } + val out = FloatArray(m * n) + for (i in 0 until m) { + val aOff = i * k + for (j in 0 until n) { + var sum = 0f + for (p in 0 until k) { + sum += aBuf[aOff + p] * bBuf[p * n + j] + } + out[i * n + j] = sum } - out[i * n + j] = sum } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(Shape(m, n), a.dtype, out) as sk.ainet.lang.tensor.data.TensorData + newTensor(outData, a.dtype, a, b) } - @Suppress("UNCHECKED_CAST") - val outData = dataFactory.fromFloatArray(Shape(m, n), a.dtype, out) as sk.ainet.lang.tensor.data.TensorData - return newTensor(outData, a.dtype, a, b) } // Generic fallback for batched / non-float / non-2D cases - return matmulGeneric(a, b) + return KernelProfile.timeGeneric { matmulGeneric(a, b) } } private fun matmulGeneric( diff --git 
/**
 * Accumulating wall-clock profiler for the matmul dispatch paths.
 *
 * Three independent buckets (quant, fp32, generic) each keep the total elapsed
 * nanoseconds and the number of timed invocations. Timing is always on: wrap a
 * path in [timeQuant], [timeFp32] or [timeGeneric], read [report] after a run,
 * and call [reset] between phases.
 *
 * The counters are plain `Long` properties updated with `+=` / `++` and no
 * synchronization. A timed body that throws records nothing for that call.
 */
public object KernelProfile {
    private val clock = TimeSource.Monotonic

    /** Total nanoseconds spent inside [timeQuant] bodies since the last [reset]. */
    public var quantNanos: Long = 0
        private set

    /** Number of completed [timeQuant] invocations since the last [reset]. */
    public var quantCalls: Long = 0
        private set

    /** Total nanoseconds spent inside [timeFp32] bodies since the last [reset]. */
    public var fp32Nanos: Long = 0
        private set

    /** Number of completed [timeFp32] invocations since the last [reset]. */
    public var fp32Calls: Long = 0
        private set

    /** Total nanoseconds spent inside [timeGeneric] bodies since the last [reset]. */
    public var genericNanos: Long = 0
        private set

    /** Number of completed [timeGeneric] invocations since the last [reset]. */
    public var genericCalls: Long = 0
        private set

    /** Runs [body], then hands the elapsed nanoseconds to [record]; returns the body's result. */
    private inline fun <R> measure(body: () -> R, record: (Long) -> Unit): R {
        val started = clock.markNow()
        val result = body()
        record(started.elapsedNow().inWholeNanoseconds)
        return result
    }

    /** Times [body] into the quant bucket and returns its result. */
    public fun <R> timeQuant(body: () -> R): R =
        measure(body) { elapsed ->
            quantNanos += elapsed
            quantCalls++
        }

    /** Times [body] into the fp32 bucket and returns its result. */
    public fun <R> timeFp32(body: () -> R): R =
        measure(body) { elapsed ->
            fp32Nanos += elapsed
            fp32Calls++
        }

    /** Times [body] into the generic bucket and returns its result. */
    public fun <R> timeGeneric(body: () -> R): R =
        measure(body) { elapsed ->
            genericNanos += elapsed
            genericCalls++
        }

    /** Zeroes every bucket (elapsed time and call count). */
    public fun reset() {
        quantNanos = 0
        quantCalls = 0
        fp32Nanos = 0
        fp32Calls = 0
        genericNanos = 0
        genericCalls = 0
    }

    /**
     * Renders the per-bucket totals in milliseconds, the call counts, and each
     * bucket's share of the summed time (0.0 when nothing has been recorded).
     */
    public fun report(): String {
        val total = quantNanos + fp32Nanos + genericNanos
        val millis = { nanos: Long -> nanos / 1_000_000.0 }
        val share = { nanos: Long -> if (total > 0) 100.0 * nanos / total else 0.0 }
        return listOf(
            "[KernelProfile] matmul time breakdown:",
            " quant-NEON : ${millis(quantNanos)} ms over $quantCalls calls (${share(quantNanos)}%)",
            " fp32-scalar : ${millis(fp32Nanos)} ms over $fp32Calls calls (${share(fp32Nanos)}%)",
            " generic : ${millis(genericNanos)} ms over $genericCalls calls (${share(genericNanos)}%)",
            " matmul total : ${millis(total)} ms",
        ).joinToString(separator = "\n")
    }
}
+ */ +static inline float skainet_q8_quantize_block(const float* SKAINET_RESTRICT in, int8_t* SKAINET_RESTRICT q8) { + float maxabs = 0.0f; + for (int i = 0; i < Q4K_BLOCK_SIZE; ++i) { + const float a = fabsf(in[i]); + if (a > maxabs) maxabs = a; + } + if (maxabs == 0.0f) { + for (int i = 0; i < Q4K_BLOCK_SIZE; ++i) q8[i] = 0; + return 0.0f; + } + const float d_in = maxabs / 127.0f; + const float inv = 127.0f / maxabs; + for (int i = 0; i < Q4K_BLOCK_SIZE; ++i) { + int v = (int) lrintf(in[i] * inv); + if (v > 127) v = 127; else if (v < -127) v = -127; + q8[i] = (int8_t) v; + } + return d_in; +} + /* * Native Q4_K matrix-vector multiply matching the - * sk.ainet.backend.api.kernel.Q4KMatmulKernel SPI contract. Single - * input row times an `outputDim x inputDim` Q4_K-packed weight tensor - * laid out (blockIdx * outputDim + o) * 144 bytes. - * - * Lazy-dmin pattern: per sub-block accumulate - * codeSum[s] = sum_i input[i] * code[i] - * inputSum[s] = sum_i input[i] - * and combine once via - * acc += d * scaleIdx[s] * codeSum[s] - dMin * minIdx[s] * inputSum[s] + * sk.ainet.backend.api.kernel.Q4KMatmulKernel SPI contract. Single input row + * times an `outputDim x inputDim` Q4_K-packed weight laid out + * (blockIdx * outputDim + o) * 144 bytes. * - * Scalar single-threaded for PR 2; the tight inner loop is - * straight-line FP arithmetic so -O3 auto-vectorizes the - * codeSum/inputSum accumulators on AVX2/NEON. + * Fused int8 dot path (ggml-style): the input row is quantized to Q8 ONCE per + * 256-block (reused across all output rows), then each weight sub-block is an + * int8 dot-product against the Q8 activation: + * acc += d_in[b] * ( d * Σ_s scaleIdx[s]*intDot[s] - dMin * Σ_s minIdx[s]*intSum[s] ) + * where intDot[s] = Σ q8[i]*code[i] and intSum[s] = Σ q8[i] over the sub-block. + * On AArch64 with dotprod (asimddp) the inner dot uses vdotq_s32 (16 int8 MACs + * per instruction); otherwise a scalar integer fallback (auto-vectorized). 
+ * The index mapping (groups, lo/hi sub-blocks, input alignment) is identical to + * the previous float kernel, which was parity-checked against Panama. */ SKAINET_API void skainet_q4k_matmul( const float* SKAINET_RESTRICT input, @@ -92,86 +121,102 @@ SKAINET_API void skainet_q4k_matmul( const float* in_base = input + input_offset; float* out_base = output + output_offset; + /* Pre-quantize the whole input row to Q8 once (reused across all o). */ + int8_t* q8 = (int8_t*) malloc((size_t) input_dim * sizeof(int8_t)); + float* d_in = (float*) malloc((size_t) blocks_per_input_dim * sizeof(float)); + if (q8 == NULL || d_in == NULL) { free(q8); free(d_in); return; } + for (int32_t b = 0; b < blocks_per_input_dim; ++b) { + d_in[b] = skainet_q8_quantize_block(in_base + (size_t) b * Q4K_BLOCK_SIZE, + q8 + (size_t) b * Q4K_BLOCK_SIZE); + } + int scale_idx[Q4K_SUB_BLOCKS]; int min_idx[Q4K_SUB_BLOCKS]; - for (int32_t o = 0; o < output_dim; ++o) { - float acc = 0.0f; - - for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) { - const uint8_t* block = weight + weight_byte_offset - + (size_t)(block_idx * output_dim + o) * Q4K_BYTES_PER_BLOCK; - - /* d, dMin (FP16 LE -> FP32). */ + /* + * Loop order: block OUTER, output row INNER. The weight is packed + * block-major — (blockIdx * output_dim + o) * 144 — so for a fixed block, + * consecutive `o` are exactly 144 bytes apart: the weight bytes are read + * strictly sequentially (prefetch- and cache-line-friendly). The reverse + * order (o outer) strides output_dim*144 bytes per step (~295 KB on the + * down-proj), which on an in-order A55 with small caches makes every weight + * read a cold miss and dominates runtime regardless of inner-loop compute. + * out_base[o] is accumulated across blocks (output_dim*4 bytes stays hot in + * cache); the accumulation order over blocks is unchanged, so this is + * numerically identical to the o-outer form. 
+ */ + for (int32_t o = 0; o < output_dim; ++o) out_base[o] = 0.0f; + + for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) { + const int8_t* q8_block = q8 + (size_t) block_idx * Q4K_BLOCK_SIZE; + const float di = d_in[block_idx]; + const uint8_t* block = weight + weight_byte_offset + + (size_t)(block_idx * output_dim) * Q4K_BYTES_PER_BLOCK; + + for (int32_t o = 0; o < output_dim; ++o, block += Q4K_BYTES_PER_BLOCK) { const uint16_t d_bits = (uint16_t) block[0] | ((uint16_t) block[1] << 8); const uint16_t d_min_bits = (uint16_t) block[2] | ((uint16_t) block[3] << 8); const float d = skainet_half_to_float(d_bits); const float d_min = skainet_half_to_float(d_min_bits); - /* 12 bytes of packed (scaleIdx, minIdx) -> 8 ints each. */ skainet_q4k_decode_scales(block + 4, scale_idx, min_idx); const uint8_t* qs = block + 16; - const float* in_block = in_base + (size_t) block_idx * Q4K_BLOCK_SIZE; - /* 4 strided qs groups; group j carries sub-blocks 2j (lo) and 2j+1 (hi). */ + int64_t block_scale_dot = 0; + int64_t block_min_sum = 0; + for (int group_j = 0; group_j < 4; ++group_j) { - const uint8_t* qs_group = qs + group_j * Q4K_SUB_BLOCK_SIZE; + const uint8_t* qs_group = qs + group_j * Q4K_SUB_BLOCK_SIZE; const int sb_lo = 2 * group_j; const int sb_hi = sb_lo + 1; - const float* in_lo = in_block + sb_lo * Q4K_SUB_BLOCK_SIZE; - const float* in_hi = in_block + sb_hi * Q4K_SUB_BLOCK_SIZE; + const int8_t* q8_lo = q8_block + sb_lo * Q4K_SUB_BLOCK_SIZE; + const int8_t* q8_hi = q8_block + sb_hi * Q4K_SUB_BLOCK_SIZE; - float code_sum_lo = 0.0f, input_sum_lo = 0.0f; - float code_sum_hi = 0.0f, input_sum_hi = 0.0f; + int32_t dot_lo = 0, sum_lo = 0, dot_hi = 0, sum_hi = 0; -#ifdef SKAINET_HAVE_NEON - float32x4_t cacc_lo = vdupq_n_f32(0.0f), iacc_lo = vdupq_n_f32(0.0f); - float32x4_t cacc_hi = vdupq_n_f32(0.0f), iacc_hi = vdupq_n_f32(0.0f); +#ifdef SKAINET_HAVE_DOTPROD + int32x4_t acc_dot_lo = vdupq_n_s32(0), acc_dot_hi = vdupq_n_s32(0); + int32_t acc_sum_lo = 0, 
acc_sum_hi = 0; for (int off = 0; off < Q4K_SUB_BLOCK_SIZE; off += 16) { const uint8x16_t packed = vld1q_u8(qs_group + off); - const uint8x16_t lo_nib = vandq_u8(packed, vdupq_n_u8(0x0F)); - const uint8x16_t hi_nib = vshrq_n_u8(packed, 4); - float32x4_t cl[4], ch[4]; - skainet_neon_u8x16_to_f32x4x4(lo_nib, cl); - skainet_neon_u8x16_to_f32x4x4(hi_nib, ch); - for (int q = 0; q < 4; ++q) { - const float32x4_t v_lo = vld1q_f32(in_lo + off + q * 4); - const float32x4_t v_hi = vld1q_f32(in_hi + off + q * 4); - cacc_lo = vfmaq_f32(cacc_lo, v_lo, cl[q]); - iacc_lo = vaddq_f32(iacc_lo, v_lo); - cacc_hi = vfmaq_f32(cacc_hi, v_hi, ch[q]); - iacc_hi = vaddq_f32(iacc_hi, v_hi); - } + const int8x16_t code_lo = vreinterpretq_s8_u8(vandq_u8(packed, vdupq_n_u8(0x0F))); + const int8x16_t code_hi = vreinterpretq_s8_u8(vshrq_n_u8(packed, 4)); + const int8x16_t a_lo = vld1q_s8(q8_lo + off); + const int8x16_t a_hi = vld1q_s8(q8_hi + off); + acc_dot_lo = vdotq_s32(acc_dot_lo, code_lo, a_lo); + acc_dot_hi = vdotq_s32(acc_dot_hi, code_hi, a_hi); + acc_sum_lo += vaddlvq_s8(a_lo); + acc_sum_hi += vaddlvq_s8(a_hi); } - code_sum_lo = skainet_neon_hadd_f32(cacc_lo); - input_sum_lo = skainet_neon_hadd_f32(iacc_lo); - code_sum_hi = skainet_neon_hadd_f32(cacc_hi); - input_sum_hi = skainet_neon_hadd_f32(iacc_hi); + dot_lo = vaddvq_s32(acc_dot_lo); + dot_hi = vaddvq_s32(acc_dot_hi); + sum_lo = acc_sum_lo; + sum_hi = acc_sum_hi; #else - /* 32 iterations — auto-vectorizes cleanly under -O3. 
*/ for (int i = 0; i < Q4K_SUB_BLOCK_SIZE; ++i) { - const uint8_t b = qs_group[i]; - const float code_lo = (float)(b & 0x0F); - const float code_hi = (float)(b >> 4); - const float v_lo = in_lo[i]; - const float v_hi = in_hi[i]; - code_sum_lo += v_lo * code_lo; - input_sum_lo += v_lo; - code_sum_hi += v_hi * code_hi; - input_sum_hi += v_hi; + const uint8_t pb = qs_group[i]; + const int code_lo = (int)(pb & 0x0F); + const int code_hi = (int)(pb >> 4); + const int a_lo = (int) q8_lo[i]; + const int a_hi = (int) q8_hi[i]; + dot_lo += a_lo * code_lo; + sum_lo += a_lo; + dot_hi += a_hi * code_hi; + sum_hi += a_hi; } #endif - const float scale_lo = d * (float) scale_idx[sb_lo]; - const float offset_lo = d_min * (float) min_idx[sb_lo]; - const float scale_hi = d * (float) scale_idx[sb_hi]; - const float offset_hi = d_min * (float) min_idx[sb_hi]; - acc += code_sum_lo * scale_lo - input_sum_lo * offset_lo; - acc += code_sum_hi * scale_hi - input_sum_hi * offset_hi; + block_scale_dot += (int64_t) scale_idx[sb_lo] * dot_lo + + (int64_t) scale_idx[sb_hi] * dot_hi; + block_min_sum += (int64_t) min_idx[sb_lo] * sum_lo + + (int64_t) min_idx[sb_hi] * sum_hi; } - } - out_base[o] = acc; + out_base[o] += di * (d * (float) block_scale_dot - d_min * (float) block_min_sum); + } } + + free(q8); + free(d_in); } diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt index 8e1c9546..7cc278cf 100644 --- a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt @@ -72,14 +72,34 @@ class NativeQ4KMatmulKernelParityTest { nativeOut, 0, ) + // The native kernel quantizes the activation to int8 (Q8) for the dotprod + 
// fast path — deliberately lossy (ggml-style), so it is NOT bit-exact vs the + // float Panama reference. Per-row relative error is the WRONG gate here: + // on zero-mean uniform-random fixtures a row whose true value is ~0 shows an + // unbounded relative error from a tiny absolute one. The meaningful metric for + // a lossy kernel is the aggregate error energy — RMS(error) / RMS(signal) — + // which ggml-style validation uses. Real (smoother) LLM activations are far + // tighter than these worst-case fixtures; the end-to-end gate is the on-board + // generation output. + var sqErr = 0.0 + var sqSig = 0.0 for (o in 0 until outputDim) { - val diff = abs(refOut[o] - nativeOut[o]) - val rel = diff / (abs(refOut[o]) + 1e-9f) - assertTrue( - diff <= tol || rel < 1e-4f, - "row $o diverged: panama=${refOut[o]} native=${nativeOut[o]} diff=$diff rel=$rel tol=$tol", - ) + val d = (refOut[o] - nativeOut[o]).toDouble() + sqErr += d * d + sqSig += refOut[o].toDouble() * refOut[o].toDouble() } + val rmsErr = kotlin.math.sqrt(sqErr / outputDim) + val rmsSig = kotlin.math.sqrt(sqSig / outputDim) + val relRms = rmsErr / (rmsSig + 1e-9) + assertTrue( + relRms < AGG_REL_TOL || rmsErr < tol, + "Q8 parity exceeded: relRms=$relRms (rmsErr=$rmsErr rmsSig=$rmsSig) over $outputDim rows, tol=$AGG_REL_TOL", + ) + } + + private companion object { + // Aggregate Q8-activation RMS-relative-error bound (uniform-random worst case). 
+ const val AGG_REL_TOL = 0.03 } @Test diff --git a/skainet-data/skainet-data-api/build.gradle.kts b/skainet-data/skainet-data-api/build.gradle.kts index 175b2f3f..511226f9 100644 --- a/skainet-data/skainet-data-api/build.gradle.kts +++ b/skainet-data/skainet-data-api/build.gradle.kts @@ -46,11 +46,13 @@ kotlin { val commonMain by getting { dependencies { implementation(project(":skainet-lang:skainet-lang-core")) + implementation(libs.kotlinx.coroutines) } } commonTest.dependencies { implementation(libs.kotlin.test) + implementation(libs.kotlinx.coroutines.test) // implementation(project(":skainet-core:skainet-performance")) } } diff --git a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt index 4183ff89..51617053 100644 --- a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt +++ b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt @@ -1,10 +1,37 @@ package sk.ainet.data import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.sliceView import sk.ainet.lang.types.DType -public data class DataBatch(val x: Array>, val y: Tensor) { +public data class DataBatch( + val x: Array>, + val y: Tensor, + val indices: IntArray = IntArray(y.shape[0]) { it }, + val metadata: Map = emptyMap() +) { + /** Number of samples represented by this batch. */ + public val batchSize: Int get() = indices.size + + /** Returns a copy with [metadata] merged into the existing metadata map. */ + public fun withMetadata(metadata: Map): DataBatch = + copy(metadata = this.metadata + metadata) + + /** Returns a contiguous slice of this batch over its leading batch dimension. 
*/ + public fun slice(range: IntRange): DataBatch { + require(!range.isEmpty()) { "range must not be empty" } + require(range.first >= 0) { "range start must be non-negative" } + require(range.last < batchSize) { "range end must be within batch bounds" } + + val endExclusive = range.last + 1 + return copy( + x = x.map { tensor -> tensor.sliceLeadingDimension(range.first, endExclusive) }.toTypedArray(), + y = y.sliceLeadingDimension(range.first, endExclusive), + indices = indices.sliceArray(range) + ) + } + override fun equals(other: Any?): Boolean { if (this === other) return true if (other == null || this::class != other::class) return false @@ -13,6 +40,8 @@ public data class DataBatch(val x: Array>, val y: Ten if (!x.contentEquals(other.x)) return false if (y != other.y) return false + if (!indices.contentEquals(other.indices)) return false + if (metadata != other.metadata) return false return true } @@ -20,6 +49,18 @@ public data class DataBatch(val x: Array>, val y: Ten override fun hashCode(): Int { var result = x.contentHashCode() result = 31 * result + y.hashCode() + result = 31 * result + indices.contentHashCode() + result = 31 * result + metadata.hashCode() return result } } + +private fun Tensor.sliceLeadingDimension(start: Int, endExclusive: Int): Tensor { + if (rank == 0) return this + return sliceView { + segment { range(start, endExclusive) } + repeat(rank - 1) { + segment { all() } + } + } +} diff --git a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt index abb111a8..cc72cae7 100644 --- a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt +++ b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt @@ -1,17 +1,30 @@ package sk.ainet.data +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import sk.ainet.lang.tensor.Shape import sk.ainet.lang.types.DType import 
kotlin.math.min +import kotlin.random.Random /** Just abstract Dataset. */ public abstract class Dataset { + /** Optional shape metadata for one input sample. */ + public open val inputShape: Shape? get() = null + + /** Optional shape metadata for one output sample. */ + public open val outputShape: Shape? get() = null + /** Splits datasets on two sub-datasets according [splitRatio].*/ public abstract fun split(splitRatio: Double): Pair, Dataset> /** Returns amount of data rows. */ public abstract val xSize: Int + /** Returns amount of data rows. Alias for [xSize]. */ + public val size: Int get() = xSize + /** Returns row by index [idx]. */ public abstract fun getX(idx: Int): T @@ -21,6 +34,60 @@ public abstract class Dataset { /** Shuffles the dataset. */ public abstract fun shuffle(): Dataset + /** Shuffles the dataset, using [seed] when deterministic ordering is required. */ + public open fun shuffle(seed: Long? = null): Dataset { + if (seed == null) return shuffle() + val indices = IntArray(xSize) { it } + indices.shuffle(Random(seed)) + return IndexedDataset(this, indices) + } + + /** + * Splits the dataset with optional deterministic shuffling and label stratification. + * + * The original [split] remains the compatibility path. This overload adds the + * ML-oriented behavior expected by training pipelines without forcing every + * concrete dataset to duplicate the same index bookkeeping. + */ + public open fun split( + splitRatio: Double, + seed: Long? 
= null, + stratified: Boolean = false + ): Pair, Dataset> { + require(splitRatio > 0.0 && splitRatio < 1.0) { "splitRatio must be in (0,1)" } + if (seed == null && !stratified) return split(splitRatio) + + val (leftIndices, rightIndices) = if (stratified) { + stratifiedSplitIndices(splitRatio, seed) + } else { + val indices = IntArray(xSize) { it } + seed?.let { indices.shuffle(Random(it)) } + indices.splitAtRatio(splitRatio) + } + + return IndexedDataset(this, leftIndices) to IndexedDataset(this, rightIndices) + } + + /** Returns a dataset view containing only samples accepted by [predicate]. */ + public fun filter(predicate: (T, Y) -> Boolean): Dataset { + val indices = (0 until xSize) + .filter { idx -> predicate(getX(idx), getY(idx)) } + .toIntArray() + return IndexedDataset(this, indices) + } + + /** Returns a dataset view with transformed input samples. */ + public fun mapX(transform: (T) -> NX): Dataset = + MappedDataset(this) { x, y -> transform(x) to y } + + /** Returns a dataset view with transformed target samples. */ + public fun mapY(transform: (Y) -> NY): Dataset = + MappedDataset(this) { x, y -> x to transform(y) } + + /** Returns a dataset view with transformed input and target samples. */ + public fun transform(transformer: (T, Y) -> Pair): Dataset = + MappedDataset(this, transformer) + /** * An iterator over a [Dataset]. */ @@ -43,9 +110,164 @@ public abstract class Dataset { /** Creates data batch that starts from [batchStart] with length [batchLength]. */ protected abstract fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch + /** + * Creates a data batch for arbitrary logical sample [indices]. + * + * Concrete datasets that can tensorize non-contiguous samples should override + * this method. The default path supports contiguous ranges and fails fast for + * non-contiguous index views instead of silently returning the wrong rows. 
+ */ + protected open fun createIndexedDataBatch(indices: IntArray): DataBatch { + require(indices.isNotEmpty()) { "indices must not be empty" } + val first = indices.first() + require(first >= 0) { "indices must be non-negative" } + val contiguous = indices.withIndex().all { (offset, value) -> value == first + offset } + require(contiguous) { + "Non-contiguous data batches require createIndexedDataBatch(indices) support in the concrete dataset" + } + return createDataBatch(first, indices.size) + } + + /** Creates data batch that starts from [batchStart] with length [batchLength]. */ + public fun dataBatch(batchStart: Int, batchLength: Int): DataBatch { + require(batchStart >= 0) { "batchStart must be non-negative" } + require(batchLength >= 0) { "batchLength must be non-negative" } + require(batchStart + batchLength <= xSize) { "batch exceeds dataset size" } + return createDataBatch(batchStart, batchLength) + } + + /** Creates a data batch for arbitrary logical sample [indices]. */ + public fun dataBatch(indices: IntArray): DataBatch { + require(indices.all { it in 0 until xSize }) { "indices must be inside dataset bounds" } + return createIndexedDataBatch(indices) + } /** Returns [BatchIterator] with fixed [batchSize]. */ public fun batchIterator(batchSize: Int): BatchIterator { + require(batchSize > 0) { "batchSize must be positive" } return BatchIterator(batchSize) } -} \ No newline at end of file + + /** Returns a cold [Flow] of data batches. */ + public fun batches( + batchSize: Int, + shuffle: Boolean = true, + seed: Long? = null + ): Flow> = flow { + val source = if (shuffle) shuffle(seed) else this@Dataset + val iterator = source.batchIterator(batchSize) + while (iterator.hasNext()) { + emit(iterator.next()) + } + } + + /** Returns a cold [Flow] over [epochCount] epochs of data batches. */ + public fun epochs( + epochCount: Int, + batchSize: Int, + shuffle: Boolean = true, + seed: Long? 
= null + ): Flow> = flow { + require(epochCount >= 0) { "epochCount must be non-negative" } + for (epoch in 0 until epochCount) { + val epochSeed = seed?.plus(epoch) + val iterator = (if (shuffle) shuffle(epochSeed) else this@Dataset).batchIterator(batchSize) + while (iterator.hasNext()) { + emit(iterator.next()) + } + } + } + + private fun stratifiedSplitIndices(splitRatio: Double, seed: Long?): Pair { + val buckets = LinkedHashMap>() + for (idx in 0 until xSize) { + buckets.getOrPut(getY(idx)) { mutableListOf() }.add(idx) + } + + val random = seed?.let { Random(it) } + val left = mutableListOf() + val right = mutableListOf() + + for (bucket in buckets.values) { + val indices = bucket.toMutableList() + if (random != null) indices.shuffle(random) + val splitIndex = (indices.size * splitRatio).toInt().coerceIn(0, indices.size) + left.addAll(indices.subList(0, splitIndex)) + right.addAll(indices.subList(splitIndex, indices.size)) + } + + if (random != null) { + left.shuffle(random) + right.shuffle(random) + } + + return left.toIntArray() to right.toIntArray() + } +} + +private class IndexedDataset( + private val source: Dataset, + private val indices: IntArray +) : Dataset() { + override val inputShape: Shape? get() = source.inputShape + override val outputShape: Shape? 
/**
 * Lazy dataset view that applies [transformer] to each `(x, y)` sample of [source].
 *
 * Nothing is cached: [getX] and [getY] each run [transformer] on the full
 * source pair and keep one half of the result. Shape metadata is not forwarded
 * from [source].
 *
 * Tensor batching is not supported; [createDataBatch] always throws.
 */
private class MappedDataset<SX, SY, TX, TY>(
    private val source: Dataset<SX, SY>,
    private val transformer: (SX, SY) -> Pair<TX, TY>
) : Dataset<TX, TY>() {
    override val xSize: Int get() = source.xSize

    /** Transformed input of sample [idx]. */
    override fun getX(idx: Int): TX = transformer(source.getX(idx), source.getY(idx)).first

    /** Transformed target of sample [idx]. */
    override fun getY(idx: Int): TY = transformer(source.getX(idx), source.getY(idx)).second

    /** Shuffles through the seeded overload using a freshly drawn random seed. */
    override fun shuffle(): Dataset<TX, TY> = shuffle(Random.nextLong())

    /**
     * Splits the view in its current order at [splitRatio].
     *
     * This must not delegate to `split(splitRatio, seed = null, stratified = false)`:
     * for exactly those arguments the base overload calls `split(splitRatio)`
     * again, so the two would recurse until the stack overflows.
     *
     * @throws IllegalArgumentException if [splitRatio] is not in (0,1)
     */
    override fun split(splitRatio: Double): Pair<Dataset<TX, TY>, Dataset<TX, TY>> {
        require(splitRatio > 0.0 && splitRatio < 1.0) { "splitRatio must be in (0,1)" }
        val (left, right) = IntArray(xSize) { it }.splitAtRatio(splitRatio)
        return IndexedDataset(this, left) to IndexedDataset(this, right)
    }

    /** Always throws: a mapped view has no way to turn its samples into tensors. */
    override fun <T : DType, V> createDataBatch(batchStart: Int, batchLength: Int): DataBatch<T, V> {
        throw UnsupportedOperationException("MappedDataset cannot create tensor batches without a tensorization transform")
    }
}
splitRatio).toInt().coerceIn(0, size) + return copyOfRange(0, splitIndex) to copyOfRange(splitIndex, size) +} diff --git a/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt b/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt index e8395b94..daec89b1 100644 --- a/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt +++ b/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt @@ -1,5 +1,7 @@ package sk.ainet.data +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals import sk.ainet.context.DefaultDataExecutionContext @@ -43,15 +45,25 @@ private class FakeDataset( @Suppress("UNCHECKED_CAST") override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { val end = (batchStart + batchLength).coerceAtMost(features.size) - val sliceF = features.subList(batchStart, end) - val sliceY = labels.subList(batchStart, end) - val featureSize = if (sliceF.isNotEmpty()) sliceF[0].size else 0 + val indices = (batchStart until end).toList() + return createBatchForIndices(indices) as DataBatch + } + + @Suppress("UNCHECKED_CAST") + override fun createIndexedDataBatch(indices: IntArray): DataBatch { + return createBatchForIndices(indices.toList()) as DataBatch + } + + private fun createBatchForIndices(indices: List): DataBatch { + val sliceF = indices.map { features[it] } + val sliceY = indices.map { labels[it] } + val featureSize = sliceF.firstOrNull()?.size ?: 0 // Build a single input tensor [batch, feature] val xTensor: Tensor = tensor(ctx, FP32::class) { tensor { - shape(end - batchStart, featureSize) { - val flat = FloatArray((end - batchStart) * featureSize) + shape(indices.size, featureSize) { + val flat = FloatArray(indices.size * featureSize) var k = 0 for (i in sliceF.indices) { val row = sliceF[i] @@ -67,15 +79,14 @@ 
private class FakeDataset( // Build label tensor [batch] val yTensor: Tensor = tensor(ctx, FP32::class) { tensor { - shape(end - batchStart) { - val flat = FloatArray(end - batchStart) { idx -> sliceY[idx] } + shape(indices.size) { + val flat = FloatArray(indices.size) { idx -> sliceY[idx] } fromArray(flat) } } } - val batch = DataBatch(arrayOf(xTensor), yTensor) - return batch as DataBatch + return DataBatch(arrayOf(xTensor), yTensor) } } @@ -100,6 +111,50 @@ class DatasetAndDataBatchTest { assertNotEquals(batchA, batchC, "Batches with same values but different tensor instances should not be equal") } + @Test + fun dataBatchCarriesIndicesAndMetadata() { + val ctx = DefaultDataExecutionContext() + val x: Tensor = data(ctx) { tensor { shape(2, 2) { from(1f, 2f, 3f, 4f) } } } + val y: Tensor = tensor(ctx, FP32::class) { tensor { shape(2) { from(0f, 1f) } } } + + val batch = DataBatch( + x = arrayOf(x), + y = y, + indices = intArrayOf(10, 20), + metadata = mapOf("split" to "train") + ) + val enriched = batch.withMetadata(mapOf("epoch" to "1")) + + assertEquals(2, batch.batchSize) + assertEquals(mapOf("split" to "train", "epoch" to "1"), enriched.metadata) + assertNotEquals(batch, batch.copy(indices = intArrayOf(11, 20))) + } + + @Test + fun dataBatchSliceUsesLeadingDimension() { + val ctx = DefaultDataExecutionContext() + val x: Tensor = data(ctx) { + tensor { + shape(3, 2) { + from(1f, 2f, 3f, 4f, 5f, 6f) + } + } + } + val y: Tensor = tensor(ctx, FP32::class) { tensor { shape(3) { from(0f, 1f, 2f) } } } + val batch = DataBatch(arrayOf(x), y, indices = intArrayOf(4, 5, 6)) + + val sliced = batch.slice(1..2) + + assertEquals(2, sliced.batchSize) + assertEquals(listOf(5, 6), sliced.indices.toList()) + assertEquals(2, sliced.x[0].shape[0]) + assertEquals(2, sliced.y.shape[0]) + assertEquals(3f, sliced.x[0].data[0, 0]) + assertEquals(6f, sliced.x[0].data[1, 1]) + assertEquals(1f, sliced.y.data[0]) + assertEquals(2f, sliced.y.data[1]) + } + @Test fun 
batchIteratorProducesCorrectSlices() { val feats = listOf( @@ -168,4 +223,106 @@ class DatasetAndDataBatchTest { val b = (0 until shuffled.xSize).map { shuffled.getX(it).joinToString(",") }.sorted() assertEquals(a, b) } + + @Test + fun sizeAliasMatchesXSize() { + val ds = FakeDataset( + features = (1..3).map { i -> floatArrayOf(i.toFloat()) }, + labels = listOf(0f, 1f, 0f) + ) + + assertEquals(ds.xSize, ds.size) + } + + @Test + fun seededShuffleIsDeterministic() { + val ds = FakeDataset( + features = (1..12).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..12).map { (it % 2).toFloat() } + ) + + val first = ds.shuffle(seed = 42) + val second = ds.shuffle(seed = 42) + val third = ds.shuffle(seed = 99) + + val firstOrder = (0 until first.xSize).map { first.getX(it)[0] } + val secondOrder = (0 until second.xSize).map { second.getX(it)[0] } + val thirdOrder = (0 until third.xSize).map { third.getX(it)[0] } + + assertEquals(firstOrder, secondOrder) + assertNotEquals(firstOrder, thirdOrder) + } + + @Test + fun seededSplitIsDeterministic() { + val ds = FakeDataset( + features = (1..10).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..10).map { (it % 2).toFloat() } + ) + + val (trainA, testA) = ds.split(splitRatio = 0.7, seed = 123) + val (trainB, testB) = ds.split(splitRatio = 0.7, seed = 123) + + assertEquals((0 until trainA.xSize).map { trainA.getX(it)[0] }, (0 until trainB.xSize).map { trainB.getX(it)[0] }) + assertEquals((0 until testA.xSize).map { testA.getX(it)[0] }, (0 until testB.xSize).map { testB.getX(it)[0] }) + } + + @Test + fun stratifiedSplitKeepsBothLabelsInBothSides() { + val ds = FakeDataset( + features = (1..20).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..20).map { (it % 2).toFloat() } + ) + + val (train, test) = ds.split(splitRatio = 0.5, seed = 7, stratified = true) + + assertEquals(setOf(0f, 1f), (0 until train.xSize).map { train.getY(it) }.toSet()) + assertEquals(setOf(0f, 1f), (0 until test.xSize).map { test.getY(it) }.toSet()) 
+ } + + @Test + fun filterAndMapCreateDatasetViews() { + val ds = FakeDataset( + features = (1..6).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..6).map { (it % 2).toFloat() } + ) + + val filtered = ds.filter { _, label -> label == 0f } + val mapped = filtered + .mapX { features -> features[0].toInt() } + .mapY { label -> "class-${label.toInt()}" } + + assertEquals(3, mapped.xSize) + assertEquals(listOf(2, 4, 6), (0 until mapped.xSize).map { mapped.getX(it) }) + assertEquals(listOf("class-0", "class-0", "class-0"), (0 until mapped.xSize).map { mapped.getY(it) }) + } + + @Test + fun batchesFlowEmitsAllBatches() = runTest { + val ds = FakeDataset( + features = (1..5).map { i -> floatArrayOf(i.toFloat(), (i * 10).toFloat()) }, + labels = (1..5).map { (it % 2).toFloat() } + ) + + val batches = ds.batches(batchSize = 2, shuffle = false).toList() + + assertEquals(3, batches.size) + assertEquals(2, batches[0].y.shape[0]) + assertEquals(2, batches[1].y.shape[0]) + assertEquals(1, batches[2].y.shape[0]) + } + + @Test + fun epochsFlowUsesSeededShufflePerEpoch() = runTest { + val ds = FakeDataset( + features = (1..6).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..6).map { (it % 2).toFloat() } + ) + + val batches = ds.epochs(epochCount = 2, batchSize = 3, shuffle = true, seed = 11).toList() + + assertEquals(4, batches.size) + assertEquals(3, batches[0].y.shape[0]) + assertEquals(3, batches[3].y.shape[0]) + } } diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt index 1c63b86f..134722df 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt @@ -102,10 +102,23 @@ public data class CIFAR10Dataset( */ override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch 
{ val actualLen = min(batchLength, xSize - batchStart) - val batchImages = images.subList(batchStart, batchStart + actualLen) + val batchIndices = IntArray(actualLen) { offset -> batchStart + offset } + return createDataBatchForIndices(batchIndices) + } + + /** + * Creates a DataBatch from arbitrary sample indices. This keeps shuffled and + * filtered dataset views compatible with tensor batching. + */ + override fun createIndexedDataBatch(indices: IntArray): DataBatch = + createDataBatchForIndices(indices) + + @Suppress("UNCHECKED_CAST") + private fun createDataBatchForIndices(indices: IntArray): DataBatch { + val batchImages = indices.map { images[it] } // Concatenate raw image bytes (no normalization) for memory efficiency - val xData = ByteArray(actualLen * CIFAR10Constants.IMAGE_BYTES) + val xData = ByteArray(indices.size * CIFAR10Constants.IMAGE_BYTES) var offset = 0 for (sample in batchImages) { val bytes = sample.image @@ -114,19 +127,18 @@ public data class CIFAR10Dataset( } // Shape as [batch, 3, 32, 32] (channel-first) - val xShape = Shape(actualLen, CIFAR10Constants.NUM_CHANNELS, CIFAR10Constants.IMAGE_SIZE, CIFAR10Constants.IMAGE_SIZE) + val xShape = Shape(indices.size, CIFAR10Constants.NUM_CHANNELS, CIFAR10Constants.IMAGE_SIZE, CIFAR10Constants.IMAGE_SIZE) val xTensor: Tensor = executionContext.fromByteArray(xShape, Int8::class, xData) // Labels as bytes (memory-efficient) - val yData = ByteArray(actualLen) { idx -> batchImages[idx].label } - val yShape = Shape(actualLen) + val yData = ByteArray(indices.size) { idx -> batchImages[idx].label } + val yShape = Shape(indices.size) val yTensor: Tensor = executionContext.fromByteArray(yShape, Int8::class, yData) // DataBatch expects array of input tensors; we provide single input val xArray: Array> = arrayOf(xTensor) - @Suppress("UNCHECKED_CAST") - return DataBatch(xArray as Array>, yTensor as Tensor) + return DataBatch(xArray as Array>, yTensor as Tensor, indices = indices.copyOf()) } /** diff --git 
a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetLoaderUnsupportedTargetException.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetLoaderUnsupportedTargetException.kt new file mode 100644 index 00000000..70137a49 --- /dev/null +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetLoaderUnsupportedTargetException.kt @@ -0,0 +1,20 @@ +package sk.ainet.data.common + +/** + * Raised when a built-in dataset loader is compiled for a target where the + * required transport, archive, or decompression primitive is not implemented. + */ +public class DatasetLoaderUnsupportedTargetException( + public val dataset: String, + public val target: String, + reason: String +) : UnsupportedOperationException("$dataset loader is not supported on $target: $reason") + +/** Throws a typed unsupported-target exception for built-in dataset loaders. */ +public fun unsupportedDatasetLoader(dataset: String, target: String, reason: String): Nothing { + throw DatasetLoaderUnsupportedTargetException(dataset, target, reason) +} + +/** Returns true when this byte array starts with the gzip magic header. 
*/ +public fun ByteArray.hasGzipHeader(): Boolean = + size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt index c9286506..7ee50a34 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt @@ -104,10 +104,23 @@ public data class FashionMNISTDataset( */ override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { val actualLen = min(batchLength, xSize - batchStart) - val batchImages = images.subList(batchStart, batchStart + actualLen) + val batchIndices = IntArray(actualLen) { offset -> batchStart + offset } + return createDataBatchForIndices(batchIndices) + } + + /** + * Creates a DataBatch from arbitrary sample indices. This keeps shuffled and + * filtered dataset views compatible with tensor batching. 
+ */ + override fun createIndexedDataBatch(indices: IntArray): DataBatch = + createDataBatchForIndices(indices) + + @Suppress("UNCHECKED_CAST") + private fun createDataBatchForIndices(indices: IntArray): DataBatch { + val batchImages = indices.map { images[it] } // Concatenate raw image bytes (no normalization) for memory efficiency - val xData = ByteArray(actualLen * FashionMNISTConstants.IMAGE_PIXELS) + val xData = ByteArray(indices.size * FashionMNISTConstants.IMAGE_PIXELS) var offset = 0 for (sample in batchImages) { val bytes = sample.image @@ -116,19 +129,18 @@ public data class FashionMNISTDataset( } // Shape as [batch, 1, 28, 28] - val xShape = Shape(actualLen, 1, FashionMNISTConstants.IMAGE_SIZE, FashionMNISTConstants.IMAGE_SIZE) + val xShape = Shape(indices.size, 1, FashionMNISTConstants.IMAGE_SIZE, FashionMNISTConstants.IMAGE_SIZE) val xTensor: Tensor = executionContext.fromByteArray(xShape, Int8::class, xData) // Labels as bytes (memory-efficient) - val yData = ByteArray(actualLen) { idx -> batchImages[idx].label } - val yShape = Shape(actualLen) + val yData = ByteArray(indices.size) { idx -> batchImages[idx].label } + val yShape = Shape(indices.size) val yTensor: Tensor = executionContext.fromByteArray(yShape, Int8::class, yData) // DataBatch expects array of input tensors; we provide single input val xArray: Array> = arrayOf(xTensor) - @Suppress("UNCHECKED_CAST") - return DataBatch(xArray as Array>, yTensor as Tensor) + return DataBatch(xArray as Array>, yTensor as Tensor, indices = indices.copyOf()) } /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt index ae6881cf..a4875219 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt @@ -82,10 +82,23 @@ public data class 
MNISTDataset( */ override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { val actualLen = min(batchLength, xSize - batchStart) - val batchImages = images.subList(batchStart, batchStart + actualLen) + val batchIndices = IntArray(actualLen) { offset -> batchStart + offset } + return createDataBatchForIndices(batchIndices) + } + + /** + * Creates a DataBatch from arbitrary sample indices. This keeps shuffled and + * filtered dataset views compatible with tensor batching. + */ + override fun createIndexedDataBatch(indices: IntArray): DataBatch = + createDataBatchForIndices(indices) + + @Suppress("UNCHECKED_CAST") + private fun createDataBatchForIndices(indices: IntArray): DataBatch { + val batchImages = indices.map { images[it] } // Concatenate raw image bytes (no normalization) for memory efficiency - val xData = ByteArray(actualLen * MNISTConstants.IMAGE_PIXELS) + val xData = ByteArray(indices.size * MNISTConstants.IMAGE_PIXELS) var offset = 0 for (sample in batchImages) { val bytes = sample.image @@ -94,19 +107,18 @@ public data class MNISTDataset( } // Shape as [batch, 1, 28, 28] - val xShape = Shape(actualLen, 1, MNISTConstants.IMAGE_SIZE, MNISTConstants.IMAGE_SIZE) + val xShape = Shape(indices.size, 1, MNISTConstants.IMAGE_SIZE, MNISTConstants.IMAGE_SIZE) val xTensor: Tensor = executionContext.fromByteArray(xShape, Int8::class, xData) // Labels as bytes (memory-efficient). 
Keep as Int8 to satisfy DataBatch single dtype requirement - val yData = ByteArray(actualLen) { idx -> batchImages[idx].label } - val yShape = Shape(actualLen) + val yData = ByteArray(indices.size) { idx -> batchImages[idx].label } + val yShape = Shape(indices.size) val yTensor: Tensor = executionContext.fromByteArray(yShape, Int8::class, yData) // DataBatch expects array of input tensors; we provide single input val xArray: Array> = arrayOf(xTensor) - @Suppress("UNCHECKED_CAST") - return DataBatch(xArray as Array>, yTensor as Tensor) + return DataBatch(xArray as Array>, yTensor as Tensor, indices = indices.copyOf()) } /** diff --git a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt index 88384f31..7d0a24a5 100644 --- a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt +++ b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * iOS implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderIos(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderIos is not fully implemented yet. 
Tar.gz extraction requires native implementation.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "ios", + reason = "tar.gz extraction is not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt index ea97291e..98d96341 100644 --- a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt +++ b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt @@ -7,6 +7,8 @@ import io.ktor.client.request.get import io.ktor.client.statement.HttpResponse import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * iOS implementation of the Fashion-MNIST loader. @@ -29,9 +31,13 @@ public class FashionMNISTLoaderIos(config: FashionMNISTLoaderConfig) : FashionMN println("Downloading Fashion-MNIST file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a native gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for iOS clients - println("iOS implementation does not support gzip decompression in this example. 
Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "ios", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } diff --git a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt index d39fb55b..c422c9bd 100644 --- a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt +++ b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt @@ -7,6 +7,8 @@ import io.ktor.client.request.get import io.ktor.client.statement.HttpResponse import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * iOS implementation of the MNIST loader. @@ -29,9 +31,13 @@ public class MNISTLoaderIos(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi println("Downloading file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a native gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for iOS clients - println("iOS implementation does not support gzip decompression in this example. 
Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "MNIST", + target = "ios", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } @@ -85,4 +91,4 @@ public class MNISTLoaderIos(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi return MNISTLoaderIos(config) } } -} \ No newline at end of file +} diff --git a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt index efe9ca1c..3e1fe62f 100644 --- a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt +++ b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * JS (browser) implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderJs(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon( override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderJs is not fully implemented yet. 
Tar.gz extraction requires JS library support.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "js", + reason = "tar.gz extraction is not implemented for this browser target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt index d461e02f..114dc37d 100644 --- a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt +++ b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt @@ -7,9 +7,13 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import kotlin.js.ExperimentalWasmJsInterop import kotlin.js.Promise import kotlinx.coroutines.await +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader +@OptIn(ExperimentalWasmJsInterop::class) @JsFun( """ async function(input) { @@ -45,7 +49,13 @@ public class FashionMNISTLoaderJs(config: FashionMNISTLoaderConfig) : FashionMNI if (decompressed != null) { decompressed } else { - println("[FashionMNIST][JS] DecompressionStream not available. 
Returning raw data (likely gzip) which will fail to parse.") + if (gzData.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "js", + reason = "browser DecompressionStream is unavailable; provide an uncompressed IDX URI" + ) + } gzData } } diff --git a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt index 797994cc..3c9051b1 100644 --- a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt +++ b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt @@ -7,9 +7,13 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import kotlin.js.ExperimentalWasmJsInterop import kotlin.js.Promise import kotlinx.coroutines.await +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader +@OptIn(ExperimentalWasmJsInterop::class) @JsFun( """ async function(input) { @@ -45,7 +49,13 @@ public class MNISTLoaderJs(config: MNISTLoaderConfig) : MNISTLoaderCommon(config if (decompressed != null) { decompressed } else { - println("[MNIST][JS] DecompressionStream not available. 
Returning raw data (likely gzip) which will fail to parse.") + if (gzData.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "MNIST", + target = "js", + reason = "browser DecompressionStream is unavailable; provide an uncompressed IDX URI" + ) + } gzData } } diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt index 541516b3..9d4d363f 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.cifar10.CIFAR10Image import sk.ainet.data.cifar10.CIFAR10LoaderConfig import sk.ainet.data.cifar10.CIFAR10LoaderCommon import sk.ainet.data.cifar10.createCIFAR10Loader +import sk.ainet.lang.types.Int8 import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -56,6 +57,19 @@ class CIFAR10LoaderTest { assertEquals(dataset.images[1], subset.images[1]) } + @Test + fun testShuffledDatasetViewCanCreateBatch() = runBlocking { + val dataset = createFakeLoader().loadTrainingData() + val shuffled = dataset.shuffle(seed = 123) + + val batch = shuffled.batchIterator(4).next() + + assertEquals(4, batch.batchSize) + assertEquals(4, batch.indices.size) + assertEquals(4, batch.x[0].shape[0]) + assertEquals(4, batch.y.shape[0]) + } + @Test fun testLoaderConfiguration() { val config = CIFAR10LoaderConfig( diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt index 30135890..2d11cb59 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt +++ 
b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.fashionmnist.FashionMNISTImage import sk.ainet.data.fashionmnist.FashionMNISTLoaderConfig import sk.ainet.data.fashionmnist.FashionMNISTLoaderCommon import sk.ainet.data.fashionmnist.createFashionMNISTLoader +import sk.ainet.lang.types.Int8 import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -56,6 +57,19 @@ class FashionMNISTLoaderTest { assertEquals(dataset.images[1], subset.images[1]) } + @Test + fun testShuffledDatasetViewCanCreateBatch() = runBlocking { + val dataset = createFakeLoader().loadTrainingData() + val shuffled = dataset.shuffle(seed = 123) + + val batch = shuffled.batchIterator(2).next() + + assertEquals(2, batch.batchSize) + assertEquals(2, batch.indices.size) + assertEquals(2, batch.x[0].shape[0]) + assertEquals(2, batch.y.shape[0]) + } + @Test fun testLoaderConfiguration() { val config = FashionMNISTLoaderConfig( diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt index 94796645..1de5fc26 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.mnist.MNISTImage import sk.ainet.data.mnist.MNISTLoaderConfig import sk.ainet.data.mnist.MNISTLoaderFactory import sk.ainet.data.mnist.MNISTLoaderCommon +import sk.ainet.lang.types.Int8 import java.nio.file.Files import kotlin.test.Test import kotlin.test.assertEquals @@ -57,6 +58,19 @@ class MNISTLoaderTest { assertEquals(dataset.images[1], subset.images[1]) } + @Test + fun testShuffledDatasetViewCanCreateBatch() = runBlocking { + val dataset = createFakeLoader().loadTrainingData() + val 
shuffled = dataset.shuffle(seed = 123) + + val batch = shuffled.batchIterator(2).next() + + assertEquals(2, batch.batchSize) + assertEquals(2, batch.indices.size) + assertEquals(2, batch.x[0].shape[0]) + assertEquals(2, batch.y.shape[0]) + } + @Test fun testLoaderConfiguration() { val config = MNISTLoaderConfig( diff --git a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt index eb363198..3f06c922 100644 --- a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt +++ b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * Linux implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderLinux(config: CIFAR10LoaderConfig) : CIFAR10LoaderComm override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderLinux is not implemented yet.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "linux", + reason = "tar.gz extraction is not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt index 984cad46..ff1cf0f6 100644 --- a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt +++ b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt @@ -2,6 +2,7 @@ package sk.ainet.data.fashionmnist import 
kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * Linux implementation of the Fashion-MNIST loader. @@ -14,7 +15,11 @@ public class FashionMNISTLoaderLinux(config: FashionMNISTLoaderConfig) : Fashion */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("FashionMNISTLoaderLinux.downloadAndCacheFile is not implemented yet.") + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "linux", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt index 11edb857..046219c3 100644 --- a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt +++ b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt @@ -2,6 +2,7 @@ package sk.ainet.data.mnist import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * Linux implementation of the MNIST loader. 
@@ -14,7 +15,11 @@ public class MNISTLoaderLinux(config: MNISTLoaderConfig) : MNISTLoaderCommon(con */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("MNISTLoaderLinux.downloadAndCacheFile is not implemented yet.") + unsupportedDatasetLoader( + dataset = "MNIST", + target = "linux", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt index 3d1eab89..45839246 100644 --- a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt +++ b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * macOS implementation of the CIFAR-10 loader. 
@@ -13,7 +14,11 @@ public class CIFAR10LoaderMacos(config: CIFAR10LoaderConfig) : CIFAR10LoaderComm override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderMacos is not implemented yet.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "macos", + reason = "tar.gz extraction is not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt index e57d5809..5e6938d1 100644 --- a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt +++ b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.fashionmnist import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * macOS implementation of the Fashion-MNIST loader. @@ -14,7 +15,11 @@ public class FashionMNISTLoaderMacos(config: FashionMNISTLoaderConfig) : Fashion */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("FashionMNISTLoaderMacos.downloadAndCacheFile is not implemented yet. 
Provide an appleMain implementation or avoid macOS usage for now.") + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "macos", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt index ca976d5f..acfc80ac 100644 --- a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt +++ b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.mnist import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * macOS implementation of the MNIST loader. @@ -14,7 +15,11 @@ public class MNISTLoaderMacos(config: MNISTLoaderConfig) : MNISTLoaderCommon(con */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("MNISTLoaderMacos.downloadAndCacheFile is not implemented yet. 
Provide an appleMain implementation or avoid macOS usage for now.") + unsupportedDatasetLoader( + dataset = "MNIST", + target = "macos", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt index 85b4648f..2ac5a0f7 100644 --- a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt +++ b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * WASM JS implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderWasmJs(config: CIFAR10LoaderConfig) : CIFAR10LoaderCom override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderWasmJs is not implemented yet.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "wasmJs", + reason = "tar.gz extraction is not implemented for this browser target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt index beab29b4..3af47dcc 100644 --- a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt +++ b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt @@ -7,6 +7,8 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body 
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * WASM JS implementation of the Fashion-MNIST loader. @@ -29,9 +31,13 @@ public class FashionMNISTLoaderWasmJs(config: FashionMNISTLoaderConfig) : Fashio println("Downloading Fashion-MNIST file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a JS gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for WASM clients - println("WASM JS implementation does not support gzip decompression. Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "wasmJs", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } diff --git a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt index 3886bdff..78b6bbe7 100644 --- a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt +++ b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt @@ -7,6 +7,8 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * WASM JS implementation of the MNIST loader. 
@@ -29,9 +31,13 @@ public class MNISTLoaderWasmJs(config: MNISTLoaderConfig) : MNISTLoaderCommon(co println("Downloading file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a JS gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for WASM clients - println("WASM JS implementation does not support gzip decompression. Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "MNIST", + target = "wasmJs", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts index 7d7c3065..bcd69b4e 100644 --- a/skainet-data/skainet-data-source/build.gradle.kts +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -19,6 +19,7 @@ kotlin { commonMain.dependencies { implementation(libs.kotlinx.coroutines) implementation(libs.kotlinx.io.core) + implementation(libs.kotlinx.serialization.json) } commonTest.dependencies { diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt new file mode 100644 index 00000000..d6bf7ec0 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt @@ -0,0 +1,211 @@ +package sk.ainet.data.source + +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive + +/** Supported built-in raw dataset formats. 
*/ +public enum class DataFormat(public val extensions: Set) { + CSV(setOf("csv")), + TSV(setOf("tsv")), + JSON(setOf("json")), + JSON_LINES(setOf("jsonl", "ndjson")); + + public companion object { + public fun fromExtension(extension: String): DataFormat? { + val normalized = extension.trim().lowercase().removePrefix(".") + if (normalized.isBlank()) return null + return entries.firstOrNull { format -> normalized in format.extensions } + } + + public fun inferFromFilename(filename: String): DataFormat? { + val normalized = filename + .substringBefore('?') + .substringBefore('#') + .trimEnd('/') + .substringAfterLast('/') + val extension = normalized.substringAfterLast('.', missingDelimiterValue = "") + return fromExtension(extension) + } + } +} + +/** A simple schema inferred from raw stringly parsed data. */ +public data class DataSchema( + public val columns: List +) + +/** One parsed row from a raw tabular or JSON-lines dataset. */ +public data class RawDataRow( + public val values: Map +) + +/** Parsed raw dataset plus schema and lightweight provenance metadata. */ +public data class RawDataset( + public val rows: List, + public val schema: DataSchema, + public val metadata: Map = emptyMap() +) + +/** Parser contract for converting source text into a [RawDataset]. */ +public interface DataFormatParser { + public val format: DataFormat + + public fun parse(text: String): RawDataset +} + +/** Registry for built-in and user-provided data format parsers. 
*/ +public class DataFormatParserRegistry( + parsers: Iterable = defaultDataFormatParsers() +) { + private val parsersByFormat: MutableMap = + parsers.associateBy { it.format }.toMutableMap() + + public fun register(parser: DataFormatParser) { + parsersByFormat[parser.format] = parser + } + + public fun parserFor(format: DataFormat): DataFormatParser = + parsersByFormat[format] ?: throw DataSourceException("No parser registered for $format") + + public fun parse(format: DataFormat, text: String): RawDataset = + parserFor(format).parse(text) + + public companion object { + public fun default(): DataFormatParserRegistry = DataFormatParserRegistry() + } +} + +/** Returns the default built-in parser set. */ +public fun defaultDataFormatParsers(): List = + listOf( + DelimitedTextDataFormatParser(DataFormat.CSV, delimiter = ','), + DelimitedTextDataFormatParser(DataFormat.TSV, delimiter = '\t'), + JsonDataFormatParser(), + JsonLinesDataFormatParser() + ) + +/** Parser for simple delimited text with a required header row. 
*/ +public class DelimitedTextDataFormatParser( + override val format: DataFormat, + private val delimiter: Char +) : DataFormatParser { + override fun parse(text: String): RawDataset { + val lines = text.lineSequence() + .map { it.trimEnd('\r') } + .filter { it.isNotBlank() } + .toList() + require(lines.isNotEmpty()) { "$format input must contain a header row" } + + val columns = splitDelimitedLine(lines.first(), delimiter) + require(columns.all { it.isNotBlank() }) { "$format header columns must not be blank" } + + val rows = lines.drop(1).mapIndexed { index, line -> + val values = splitDelimitedLine(line, delimiter) + require(values.size == columns.size) { + "$format row ${index + 2} has ${values.size} values but expected ${columns.size}" + } + RawDataRow(columns.zip(values).toMap()) + } + + return RawDataset( + rows = rows, + schema = DataSchema(columns), + metadata = mapOf("format" to format.name, "rowCount" to rows.size.toString()) + ) + } +} + +/** Parser for JSON object datasets encoded as one object or an array of objects. */ +public class JsonDataFormatParser( + private val json: Json = Json +) : DataFormatParser { + override val format: DataFormat = DataFormat.JSON + + override fun parse(text: String): RawDataset { + val root = json.parseToJsonElement(text) + val objects = when (root) { + is JsonObject -> listOf(root) + is JsonArray -> root.mapIndexed { index, element -> + require(element is JsonObject) { "JSON array element ${index + 1} must be an object" } + element + } + else -> throw IllegalArgumentException("JSON input must be an object or array of objects") + } + + return objects.toRawDataset(format) + } +} + +/** Parser for newline-delimited JSON objects. 
*/ +public class JsonLinesDataFormatParser( + private val json: Json = Json +) : DataFormatParser { + override val format: DataFormat = DataFormat.JSON_LINES + + override fun parse(text: String): RawDataset { + val objects = text.lineSequence() + .map { it.trim() } + .filter { it.isNotEmpty() } + .mapIndexed { index, line -> + val element = json.parseToJsonElement(line) + require(element is JsonObject) { "JSON_LINES row ${index + 1} must be a JSON object" } + element + } + .toList() + + return objects.toRawDataset(format) + } +} + +private fun List.toRawDataset(format: DataFormat): RawDataset { + val columns = flatMap { it.keys }.distinct() + val rows = map { obj -> + RawDataRow(columns.associateWith { column -> obj[column]?.toRawString().orEmpty() }) + } + + return RawDataset( + rows = rows, + schema = DataSchema(columns), + metadata = mapOf("format" to format.name, "rowCount" to rows.size.toString()) + ) +} + +private fun splitDelimitedLine(line: String, delimiter: Char): List { + val values = mutableListOf() + val current = StringBuilder() + var quoted = false + var index = 0 + + while (index < line.length) { + val char = line[index] + when { + char == '"' && quoted && index + 1 < line.length && line[index + 1] == '"' -> { + current.append('"') + index++ + } + char == '"' -> quoted = !quoted + char == delimiter && !quoted -> { + values.add(current.toString()) + current.clear() + } + else -> current.append(char) + } + index++ + } + + require(!quoted) { "Unterminated quoted field" } + values.add(current.toString()) + return values +} + +private fun JsonElement.toRawString(): String = + when (this) { + JsonNull -> "" + is JsonPrimitive -> content + is JsonArray -> toString() + is JsonObject -> toString() + } diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataPipeline.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataPipeline.kt new file mode 100644 index 00000000..82b26df1 --- /dev/null 
+++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataPipeline.kt @@ -0,0 +1,111 @@ +package sk.ainet.data.source + +/** A named suspendable data processing stage. */ +public interface PipelineStage { + public val name: String + + public fun validate(input: I): Boolean = true + + public suspend fun process(input: I): O +} + +/** A schema-aware stage for data preprocessing pipelines. */ +public interface DataTransformer : PipelineStage { + public suspend fun transform(input: I): O + + public fun getOutputSchema(inputSchema: DataSchema): DataSchema = inputSchema + + override suspend fun process(input: I): O = transform(input) +} + +/** Thrown when a data pipeline cannot execute a stage. */ +public class DataPipelineException( + message: String, + cause: Throwable? = null +) : DataSourceException(message, cause) + +/** A type-safe sequential data pipeline. */ +public class DataPipeline internal constructor( + private val stages: List> +) { + public val stageNames: List = stages.map { stage -> stage.name } + + /** Adds [stage] to the end of this pipeline. */ + public fun stage(stage: PipelineStage): DataPipeline = + DataPipeline(stages + stage) + + /** Adds [stage] to the end of this pipeline. */ + public infix fun then(stage: PipelineStage): DataPipeline = + this.stage(stage) + + /** Executes each stage in order. */ + @Suppress("UNCHECKED_CAST") + public suspend fun execute(input: I): O { + var current: Any? = input + for (stage in stages) { + val typedStage = stage as PipelineStage + if (!typedStage.validate(current)) { + throw DataPipelineException("Stage '${stage.name}' rejected its input") + } + current = typedStage.process(current) + } + return current as O + } + + /** Returns a human-readable stage chain. */ + public fun describe(): String = stageNames.joinToString(" -> ") +} + +/** Starts an identity data pipeline. 
*/ +public fun dataPipeline(): DataPipeline = DataPipeline(emptyList()) + +/** Creates a named suspendable pipeline stage. */ +public fun pipelineStage( + name: String, + validate: (I) -> Boolean = { true }, + process: suspend (I) -> O +): PipelineStage { + require(name.isNotBlank()) { "Pipeline stage name must not be blank" } + return FunctionPipelineStage(name, validate, process) +} + +/** Kotlinish alias for [pipelineStage]. */ +public fun stage( + name: String, + validate: (I) -> Boolean = { true }, + process: suspend (I) -> O +): PipelineStage = pipelineStage(name, validate, process) + +/** Creates a named schema-aware data transformer. */ +public fun dataTransformer( + name: String, + outputSchema: (DataSchema) -> DataSchema = { it }, + validate: (I) -> Boolean = { true }, + transform: suspend (I) -> O +): DataTransformer { + require(name.isNotBlank()) { "Data transformer name must not be blank" } + return FunctionDataTransformer(name, validate, outputSchema, transform) +} + +private class FunctionPipelineStage( + override val name: String, + private val validateInput: (I) -> Boolean, + private val processInput: suspend (I) -> O +) : PipelineStage { + override fun validate(input: I): Boolean = validateInput(input) + + override suspend fun process(input: I): O = processInput(input) +} + +private class FunctionDataTransformer( + override val name: String, + private val validateInput: (I) -> Boolean, + private val outputSchema: (DataSchema) -> DataSchema, + private val transformInput: suspend (I) -> O +) : DataTransformer { + override fun validate(input: I): Boolean = validateInput(input) + + override suspend fun transform(input: I): O = transformInput(input) + + override fun getOutputSchema(inputSchema: DataSchema): DataSchema = outputSchema(inputSchema) +} diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceDatasetBuilder.kt 
b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceDatasetBuilder.kt new file mode 100644 index 00000000..0e7b1bd7 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceDatasetBuilder.kt @@ -0,0 +1,132 @@ +package sk.ainet.data.source + +/** + * Kotlin DSL builder for resolving a data source artifact and parsing it into + * a raw, schema-bearing dataset. + */ +public class DataSourceDatasetBuilder( + resolver: DataSourceResolver? = null, + parserRegistry: DataFormatParserRegistry = DataFormatParserRegistry.default() +) { + private var resolver: DataSourceResolver? = resolver + private var parserRegistry: DataFormatParserRegistry = parserRegistry + private var sourceUri: String? = null + private var requestedFormat: DataFormat? = null + private var cachePolicy: CachePolicy = CachePolicy.Use + private var expectedSha256: String? = null + private val requestHeaders: MutableMap = linkedMapOf() + private var huggingFaceToken: DataSourceAuthToken? = null + + /** Selects the source artifact URI or local path. */ + public fun from(uri: String): DataSourceDatasetBuilder = apply { + val normalized = uri.trim() + require(normalized.isNotEmpty()) { "Data source URI must not be blank" } + sourceUri = normalized + } + + /** Overrides format inference from the source filename. */ + public fun format(format: DataFormat): DataSourceDatasetBuilder = apply { + requestedFormat = format + } + + /** Sets resolver cache behavior for this load. */ + public fun cachePolicy(cachePolicy: CachePolicy): DataSourceDatasetBuilder = apply { + this.cachePolicy = cachePolicy + } + + /** Sets an optional expected SHA-256 checksum for resolver verification. */ + public fun expectedSha256(value: String?): DataSourceDatasetBuilder = apply { + expectedSha256 = value?.trim()?.takeIf { it.isNotEmpty() } + } + + /** Adds one request header. 
*/ + public fun header(name: String, value: String): DataSourceDatasetBuilder = apply { + val normalizedName = name.trim() + require(normalizedName.isNotEmpty()) { "Header name must not be blank" } + requestHeaders[normalizedName] = value + } + + /** Adds request headers, replacing entries with the same name. */ + public fun headers(headers: Map): DataSourceDatasetBuilder = apply { + headers.forEach { (name, value) -> header(name, value) } + } + + /** Sets a provider-specific Hugging Face token for this request. */ + public fun huggingFaceToken(token: DataSourceAuthToken?): DataSourceDatasetBuilder = apply { + huggingFaceToken = token + } + + /** Sets a provider-specific Hugging Face token for this request. */ + public fun huggingFaceToken(value: String): DataSourceDatasetBuilder = + huggingFaceToken(DataSourceAuthToken.from(value)) + + /** Supplies the source resolver used to materialize the artifact. */ + public fun resolver(resolver: DataSourceResolver): DataSourceDatasetBuilder = apply { + this.resolver = resolver + } + + /** Replaces the parser registry used by this builder. */ + public fun parserRegistry(parserRegistry: DataFormatParserRegistry): DataSourceDatasetBuilder = apply { + this.parserRegistry = parserRegistry + } + + /** Registers or replaces one parser in this builder's parser registry. */ + public fun parser(parser: DataFormatParser): DataSourceDatasetBuilder = apply { + parserRegistry.register(parser) + } + + /** Resolves and parses the configured source artifact. */ + public suspend fun build(): RawDataset { + val request = buildRequest() + val artifact = requireResolver().resolve(request) + val format = requestedFormat + ?: DataFormat.inferFromFilename(artifact.filename) + ?: throw DataSourceException( + "Cannot infer data format from '${artifact.filename}'. Specify format(...) explicitly." 
+ ) + val text = artifact.readBytes().decodeToString() + return parserRegistry.parse(format, text).withSourceMetadata(artifact) + } + + private fun requireResolver(): DataSourceResolver = + resolver ?: throw DataSourceException("A DataSourceResolver is required to load a data source dataset") + + private fun buildRequest(): DataSourceRequest { + val uri = sourceUri ?: throw DataSourceException("Data source URI is required; call from(...) first") + return DataSourceRequest( + uri = uri, + cachePolicy = cachePolicy, + expectedSha256 = expectedSha256, + headers = requestHeaders.toMap(), + huggingFaceToken = huggingFaceToken + ) + } + + private fun RawDataset.withSourceMetadata(artifact: DataSourceArtifact): RawDataset { + val sourceMetadata = mutableMapOf( + "sourceUri" to artifact.request.uri, + "sourceProvider" to artifact.parsedUri.provider.name, + "sourceFilename" to artifact.filename, + "sourceCacheHit" to artifact.cacheHit.toString() + ) + artifact.localPath?.let { sourceMetadata["sourceLocalPath"] = it } + artifact.sizeBytes?.let { sourceMetadata["sourceSizeBytes"] = it.toString() } + return copy(metadata = metadata + sourceMetadata) + } +} + +/** Builds a raw dataset by resolving and parsing a configured data source. */ +public suspend fun rawDataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + DataSourceDatasetBuilder().apply(block).build() + +/** Kotlinish alias for [rawDataset]. */ +public suspend fun dataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + rawDataset(block) + +/** Builds a raw dataset with this resolver preconfigured. */ +public suspend fun DataSourceResolver.rawDataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + DataSourceDatasetBuilder(this).apply(block).build() + +/** Kotlinish alias for [DataSourceResolver.rawDataset]. 
*/ +public suspend fun DataSourceResolver.dataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + rawDataset(block) diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt new file mode 100644 index 00000000..788d3c38 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt @@ -0,0 +1,125 @@ +package sk.ainet.data.source + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertSame + +class DataFormatParserTest { + @Test + fun parsesCsvHeaderAndQuotedValues() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.CSV, + "name,label,comment\n" + + "Ada,math,\"first, second\"\n" + + "Grace,compiler,\"said \"\"hello\"\"\"" + ) + + assertEquals(listOf("name", "label", "comment"), dataset.schema.columns) + assertEquals(2, dataset.rows.size) + assertEquals("Ada", dataset.rows[0].values["name"]) + assertEquals("first, second", dataset.rows[0].values["comment"]) + assertEquals("said \"hello\"", dataset.rows[1].values["comment"]) + assertEquals("CSV", dataset.metadata["format"]) + assertEquals("2", dataset.metadata["rowCount"]) + } + + @Test + fun parsesTsvRows() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.TSV, + "id\tvalue\n" + + "1\t42\n" + + "2\t84" + ) + + assertEquals(listOf("id", "value"), dataset.schema.columns) + assertEquals( + mapOf("id" to "2", "value" to "84"), + dataset.rows[1].values + ) + } + + @Test + fun parsesJsonLinesWithUnionSchema() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.JSON_LINES, + "{\"id\":1,\"label\":\"cat\",\"pixels\":[0,1],\"meta\":{\"split\":\"train\"}}\n" + + "{\"id\":2,\"label\":\"dog\",\"score\":0.5}" + ) + + assertEquals(listOf("id", "label", "pixels", "meta", 
"score"), dataset.schema.columns) + assertEquals("1", dataset.rows[0].values["id"]) + assertEquals("[0,1]", dataset.rows[0].values["pixels"]) + assertEquals("{\"split\":\"train\"}", dataset.rows[0].values["meta"]) + assertEquals("", dataset.rows[0].values["score"]) + assertEquals("0.5", dataset.rows[1].values["score"]) + } + + @Test + fun parsesJsonArrayWithUnionSchema() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.JSON, + "[" + + "{\"id\":1,\"label\":\"cat\",\"pixels\":[0,1]}," + + "{\"id\":2,\"label\":\"dog\",\"score\":0.5}" + + "]" + ) + + assertEquals(listOf("id", "label", "pixels", "score"), dataset.schema.columns) + assertEquals("JSON", dataset.metadata["format"]) + assertEquals("2", dataset.metadata["rowCount"]) + assertEquals("[0,1]", dataset.rows[0].values["pixels"]) + assertEquals("", dataset.rows[0].values["score"]) + } + + @Test + fun parsesJsonSingleObject() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.JSON, + "{\"id\":1,\"label\":\"cat\"}" + ) + + assertEquals(listOf("id", "label"), dataset.schema.columns) + assertEquals(mapOf("id" to "1", "label" to "cat"), dataset.rows.single().values) + } + + @Test + fun rejectsJsonArraysWithNonObjectElements() { + assertFailsWith { + DataFormatParserRegistry.default().parse(DataFormat.JSON, "[1]") + } + } + + @Test + fun replacesRegisteredParser() { + val registry = DataFormatParserRegistry(parsers = emptyList()) + val parser = object : DataFormatParser { + override val format: DataFormat = DataFormat.CSV + + override fun parse(text: String): RawDataset = + RawDataset( + rows = listOf(RawDataRow(mapOf("value" to text))), + schema = DataSchema(listOf("value")) + ) + } + + registry.register(parser) + + assertSame(parser, registry.parserFor(DataFormat.CSV)) + assertEquals("payload", registry.parse(DataFormat.CSV, "payload").rows.single().values["value"]) + assertFailsWith { + registry.parserFor(DataFormat.TSV) + } + } + + @Test + fun 
rejectsDelimitedRowsWithWrongWidth() { + assertFailsWith { + DataFormatParserRegistry.default().parse( + DataFormat.CSV, + "a,b\n1,2,3" + ) + } + } +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataPipelineTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataPipelineTest.kt new file mode 100644 index 00000000..57693eba --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataPipelineTest.kt @@ -0,0 +1,68 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class DataPipelineTest { + @Test + fun executesTypedStagesInOrder() = runTest { + val pipeline = dataPipeline() + .stage(stage("add-one") { input -> input + 1 }) + .stage(stage("stringify") { input -> "value=$input" }) + + assertEquals("add-one -> stringify", pipeline.describe()) + assertEquals(listOf("add-one", "stringify"), pipeline.stageNames) + assertEquals("value=42", pipeline.execute(41)) + } + + @Test + fun identityPipelineReturnsInput() = runTest { + val pipeline = dataPipeline() + + assertEquals("", pipeline.describe()) + assertEquals(7, pipeline.execute(7)) + } + + @Test + fun rejectsInvalidStageInput() = runTest { + val pipeline = dataPipeline() + .stage( + stage( + name = "positive", + validate = { input -> input > 0 } + ) { input -> input } + ) + + assertFailsWith { + pipeline.execute(0) + } + } + + @Test + fun transformerUpdatesSchemaAndRows() = runTest { + val dropLabel = dataTransformer( + name = "drop-label", + outputSchema = { schema -> DataSchema(schema.columns.filter { column -> column != "label" }) } + ) { dataset -> + val columns = dataset.schema.columns.filter { column -> column != "label" } + dataset.copy( + schema = DataSchema(columns), + rows = dataset.rows.map { row -> + RawDataRow(row.values.filterKeys { key -> key in columns }) + } + ) + } + 
val input = RawDataset( + rows = listOf(RawDataRow(mapOf("id" to "1", "label" to "cat"))), + schema = DataSchema(listOf("id", "label")) + ) + + val output = (dataPipeline() then dropLabel).execute(input) + + assertEquals(DataSchema(listOf("id")), dropLabel.getOutputSchema(input.schema)) + assertEquals(DataSchema(listOf("id")), output.schema) + assertEquals(mapOf("id" to "1"), output.rows.single().values) + } +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceDatasetBuilderTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceDatasetBuilderTest.kt new file mode 100644 index 00000000..ee6ec352 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceDatasetBuilderTest.kt @@ -0,0 +1,125 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class DataSourceDatasetBuilderTest { + @Test + fun resolvesSourceAndInfersCsvFormat() = runTest { + val resolver = FixtureResolver("id,label\n1,cat\n2,dog") + val token = DataSourceAuthToken.from("hf_token") + + val dataset = resolver.rawDataset { + from("hf://datasets/org/repo@main/train.csv") + cachePolicy(CachePolicy.Refresh) + expectedSha256("sha256") + header("Accept", "text/csv") + huggingFaceToken(token) + } + + assertEquals(listOf("id", "label"), dataset.schema.columns) + assertEquals(mapOf("id" to "2", "label" to "dog"), dataset.rows[1].values) + assertEquals("CSV", dataset.metadata["format"]) + assertEquals("hf://datasets/org/repo@main/train.csv", dataset.metadata["sourceUri"]) + assertEquals("HuggingFace", dataset.metadata["sourceProvider"]) + assertEquals("train.csv", dataset.metadata["sourceFilename"]) + assertEquals("false", dataset.metadata["sourceCacheHit"]) + assertEquals("20", dataset.metadata["sourceSizeBytes"]) + + val request = resolver.lastRequest + 
assertEquals(CachePolicy.Refresh, request?.cachePolicy) + assertEquals("sha256", request?.expectedSha256) + assertEquals(mapOf("Accept" to "text/csv"), request?.headers) + assertEquals(token, request?.huggingFaceToken) + } + + @Test + fun explicitFormatOverridesFilenameInference() = runTest { + val resolver = FixtureResolver("id\tvalue\n1\t42") + + val dataset = resolver.rawDataset { + from("fixtures/train.txt") + format(DataFormat.TSV) + } + + assertEquals(listOf("id", "value"), dataset.schema.columns) + assertEquals(mapOf("id" to "1", "value" to "42"), dataset.rows.single().values) + } + + @Test + fun parserRegistrationReplacesDefaultParser() = runTest { + val resolver = FixtureResolver("payload") + val customParser = object : DataFormatParser { + override val format: DataFormat = DataFormat.CSV + + override fun parse(text: String): RawDataset = + RawDataset( + rows = listOf(RawDataRow(mapOf("raw" to text.uppercase()))), + schema = DataSchema(listOf("raw")) + ) + } + + val dataset = resolver.rawDataset { + from("fixture.csv") + parser(customParser) + } + + assertEquals(mapOf("raw" to "PAYLOAD"), dataset.rows.single().values) + } + + @Test + fun failsWhenFormatCannotBeInferred() = runTest { + val resolver = FixtureResolver("payload") + + assertFailsWith { + resolver.rawDataset { + from("fixture.bin") + } + } + } + + @Test + fun failsWhenResolverIsMissing() = runTest { + assertFailsWith { + rawDataset { + from("fixture.csv") + } + } + } + + @Test + fun failsWhenSourceIsMissing() = runTest { + val resolver = FixtureResolver("payload") + + assertFailsWith { + resolver.rawDataset { + format(DataFormat.CSV) + } + } + } +} + +private class FixtureResolver( + text: String +) : DataSourceResolver { + private val bytes = text.encodeToByteArray() + + var lastRequest: DataSourceRequest? 
= null + private set + + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact { + lastRequest = request + val parsed = DataSourceUriParser.parse(request.uri) + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = parsed.localPath, + sizeBytes = bytes.size.toLong(), + cacheHit = false, + sourceOpener = { DataSourceStoredArtifact.inMemory(bytes, parsed.localPath).openSource() } + ) + } +}