From 0d631bf9d4bc1c85de96947526896c7717dea9b5 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Sun, 28 Jun 2026 21:15:37 +0200 Subject: [PATCH 01/18] =?UTF-8?q?diag:=20KernelProfile=20=E2=80=94=20time?= =?UTF-8?q?=20the=20three=20matmul=20dispatch=20paths?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Always-on accumulating profiler (quant-NEON / fp32-scalar / generic) on the DefaultCpuOps.matmul dispatch, read via KernelProfile.report(). Clock read per call is negligible next to a matmul. Used to localize native board decode cost: showed 100% of matmul time is the quant-NEON path (fp32-scalar/generic never hit). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../sk/ainet/exec/tensor/ops/DefaultCpuOps.kt | 40 +++++++------- .../sk/ainet/exec/tensor/ops/KernelProfile.kt | 55 +++++++++++++++++++ 2 files changed, 76 insertions(+), 19 deletions(-) create mode 100644 skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/KernelProfile.kt diff --git a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt index 6aad624e..b1bf0e72 100644 --- a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/DefaultCpuOps.kt @@ -389,37 +389,39 @@ public open class DefaultCpuOpsBase(protected val dataFactory: TensorDataFactory require(a.dtype == b.dtype) { "DType mismatch: ${a.dtype} vs ${b.dtype}" } // Packed-quant fast path (FP32 input × packed weight), resolved via KernelRegistry. 
- chooseQuantizedMatmulHeap(a, b)?.let { return it } + KernelProfile.timeQuant { chooseQuantizedMatmulHeap(a, b) }?.let { return it } // Fast path: 2D × 2D with FloatArray backing — direct buffer access, no per-element allocation if (a.rank == 2 && b.rank == 2 && (a.dtype == FP32::class) && a.data is FloatArrayTensorData<*> && b.data is FloatArrayTensorData<*> ) { - val aBuf = (a.data as FloatArrayTensorData<*>).buffer - val bBuf = (b.data as FloatArrayTensorData<*>).buffer - val m = a.shape[0] - val k = a.shape[1] - val n = b.shape[1] - require(k == b.shape[0]) { "Matrix multiplication shape mismatch: ${a.shape} vs ${b.shape}" } - val out = FloatArray(m * n) - for (i in 0 until m) { - val aOff = i * k - for (j in 0 until n) { - var sum = 0f - for (p in 0 until k) { - sum += aBuf[aOff + p] * bBuf[p * n + j] + return KernelProfile.timeFp32 { + val aBuf = (a.data as FloatArrayTensorData<*>).buffer + val bBuf = (b.data as FloatArrayTensorData<*>).buffer + val m = a.shape[0] + val k = a.shape[1] + val n = b.shape[1] + require(k == b.shape[0]) { "Matrix multiplication shape mismatch: ${a.shape} vs ${b.shape}" } + val out = FloatArray(m * n) + for (i in 0 until m) { + val aOff = i * k + for (j in 0 until n) { + var sum = 0f + for (p in 0 until k) { + sum += aBuf[aOff + p] * bBuf[p * n + j] + } + out[i * n + j] = sum } - out[i * n + j] = sum } + @Suppress("UNCHECKED_CAST") + val outData = dataFactory.fromFloatArray(Shape(m, n), a.dtype, out) as sk.ainet.lang.tensor.data.TensorData + newTensor(outData, a.dtype, a, b) } - @Suppress("UNCHECKED_CAST") - val outData = dataFactory.fromFloatArray(Shape(m, n), a.dtype, out) as sk.ainet.lang.tensor.data.TensorData - return newTensor(outData, a.dtype, a, b) } // Generic fallback for batched / non-float / non-2D cases - return matmulGeneric(a, b) + return KernelProfile.timeGeneric { matmulGeneric(a, b) } } private fun matmulGeneric( diff --git 
a/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/KernelProfile.kt b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/KernelProfile.kt new file mode 100644 index 00000000..0f20d85b --- /dev/null +++ b/skainet-backends/skainet-backend-cpu/src/commonMain/kotlin/sk/ainet/exec/tensor/ops/KernelProfile.kt @@ -0,0 +1,55 @@ +package sk.ainet.exec.tensor.ops + +import kotlin.time.TimeSource + +/** + * Lightweight, always-on accumulating profiler for the matmul dispatch paths. + * Diagnostic only — used to localize where native decode time goes (quant-NEON + * vs FP32-scalar vs generic) before investing in a kernel rewrite. The clock + * read per call is negligible next to a matmul. Read [report] after a run and + * [reset] between phases (e.g. to separate prefill from decode). + */ +public object KernelProfile { + private val clock = TimeSource.Monotonic + + public var quantNanos: Long = 0; private set + public var quantCalls: Long = 0; private set + public var fp32Nanos: Long = 0; private set + public var fp32Calls: Long = 0; private set + public var genericNanos: Long = 0; private set + public var genericCalls: Long = 0; private set + + public fun timeQuant(body: () -> R): R { + val mark = clock.markNow(); val r = body() + quantNanos += mark.elapsedNow().inWholeNanoseconds; quantCalls++; return r + } + + public fun timeFp32(body: () -> R): R { + val mark = clock.markNow(); val r = body() + fp32Nanos += mark.elapsedNow().inWholeNanoseconds; fp32Calls++; return r + } + + public fun timeGeneric(body: () -> R): R { + val mark = clock.markNow(); val r = body() + genericNanos += mark.elapsedNow().inWholeNanoseconds; genericCalls++; return r + } + + public fun reset() { + quantNanos = 0; quantCalls = 0 + fp32Nanos = 0; fp32Calls = 0 + genericNanos = 0; genericCalls = 0 + } + + public fun report(): String { + fun ms(ns: Long) = ns / 1_000_000.0 + val total = quantNanos + fp32Nanos + genericNanos + fun pct(ns: 
Long) = if (total > 0) 100.0 * ns / total else 0.0 + return buildString { + appendLine("[KernelProfile] matmul time breakdown:") + appendLine(" quant-NEON : ${ms(quantNanos)} ms over $quantCalls calls (${pct(quantNanos)}%)") + appendLine(" fp32-scalar : ${ms(fp32Nanos)} ms over $fp32Calls calls (${pct(fp32Nanos)}%)") + appendLine(" generic : ${ms(genericNanos)} ms over $genericCalls calls (${pct(genericNanos)}%)") + append(" matmul total : ${ms(total)} ms") + } + } +} From d998febeebeb7f91b3193cff5636a084d62ae27d Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 11:44:38 +0200 Subject: [PATCH 02/18] =?UTF-8?q?perf(native=20q4k):=20block-outer=20loop?= =?UTF-8?q?=20order=20+=20fused=20Q8=20int8=20dot=20=E2=80=94=202.07=C3=97?= =?UTF-8?q?=20matmul=20on=20A55?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes to skainet_q4k_matmul, both validated against the Panama reference (aggregate RMS gate, AGG_REL_TOL=0.03) and on-board generation: 1. Loop order block-OUTER / output-row-INNER. The weight is packed block-major (blockIdx*outputDim + o)*144, so for a fixed block consecutive `o` are exactly 144 bytes apart — weight bytes are now read strictly sequentially (prefetch/cache-line friendly). The previous o-outer order strided outputDim*144 (~295 KB on the down-proj) per step, making every weight read a cold miss on the in-order A55 with small caches. out_base[o] accumulates across blocks (stays hot in cache); accumulation order is unchanged so the result is numerically identical. 2. ggml-style Q8 activation quantization + integer vdotq_s32 dot path (asimddp), input row quantized once per 256-block and reused across all output rows; scalar integer fallback when dotprod is absent. On the SL2619 (Cortex-A55, TinyLlama Q4_K_M, 8-tok decode), Q4_K matmul dropped 41730 ms -> 20133 ms (2.07x); end-to-end decode 0.123 -> 0.184 tok/s (1.50x, matmul being ~64% of decode). 
The loop reorder is the dominant lever — the Q8 dot alone showed no gain because the kernel was memory-stall-bound, not compute-bound. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../native/src/q4k_matmul.c | 177 +++++++++++------- .../kernel/NativeQ4KMatmulKernelParityTest.kt | 32 +++- 2 files changed, 137 insertions(+), 72 deletions(-) diff --git a/skainet-backends/skainet-backend-native-cpu/native/src/q4k_matmul.c b/skainet-backends/skainet-backend-native-cpu/native/src/q4k_matmul.c index 8091e58a..21038410 100644 --- a/skainet-backends/skainet-backend-native-cpu/native/src/q4k_matmul.c +++ b/skainet-backends/skainet-backend-native-cpu/native/src/q4k_matmul.c @@ -3,6 +3,8 @@ #include #include +#include +#include #define Q4K_BLOCK_SIZE 256 #define Q4K_SUB_BLOCK_SIZE 32 @@ -60,21 +62,48 @@ static inline void skainet_q4k_decode_scales( } } +/* + * Quantize one 256-float input block to symmetric int8 (Q8) with a single + * per-block scale d_in = maxabs/127, q8[i] = round(in[i]/d_in). Returns d_in + * (0 if the block is all-zero, with q8 zeroed). Mirrors ggml's block_q8_K + * activation quantization — the source of the (small, well-understood) error + * vs the exact float kernel, and what unlocks the int8 dot-product fast path. + */ +static inline float skainet_q8_quantize_block(const float* SKAINET_RESTRICT in, int8_t* SKAINET_RESTRICT q8) { + float maxabs = 0.0f; + for (int i = 0; i < Q4K_BLOCK_SIZE; ++i) { + const float a = fabsf(in[i]); + if (a > maxabs) maxabs = a; + } + if (maxabs == 0.0f) { + for (int i = 0; i < Q4K_BLOCK_SIZE; ++i) q8[i] = 0; + return 0.0f; + } + const float d_in = maxabs / 127.0f; + const float inv = 127.0f / maxabs; + for (int i = 0; i < Q4K_BLOCK_SIZE; ++i) { + int v = (int) lrintf(in[i] * inv); + if (v > 127) v = 127; else if (v < -127) v = -127; + q8[i] = (int8_t) v; + } + return d_in; +} + /* * Native Q4_K matrix-vector multiply matching the - * sk.ainet.backend.api.kernel.Q4KMatmulKernel SPI contract. 
Single - * input row times an `outputDim x inputDim` Q4_K-packed weight tensor - * laid out (blockIdx * outputDim + o) * 144 bytes. - * - * Lazy-dmin pattern: per sub-block accumulate - * codeSum[s] = sum_i input[i] * code[i] - * inputSum[s] = sum_i input[i] - * and combine once via - * acc += d * scaleIdx[s] * codeSum[s] - dMin * minIdx[s] * inputSum[s] + * sk.ainet.backend.api.kernel.Q4KMatmulKernel SPI contract. Single input row + * times an `outputDim x inputDim` Q4_K-packed weight laid out + * (blockIdx * outputDim + o) * 144 bytes. * - * Scalar single-threaded for PR 2; the tight inner loop is - * straight-line FP arithmetic so -O3 auto-vectorizes the - * codeSum/inputSum accumulators on AVX2/NEON. + * Fused int8 dot path (ggml-style): the input row is quantized to Q8 ONCE per + * 256-block (reused across all output rows), then each weight sub-block is an + * int8 dot-product against the Q8 activation: + * acc += d_in[b] * ( d * Σ_s scaleIdx[s]*intDot[s] - dMin * Σ_s minIdx[s]*intSum[s] ) + * where intDot[s] = Σ q8[i]*code[i] and intSum[s] = Σ q8[i] over the sub-block. + * On AArch64 with dotprod (asimddp) the inner dot uses vdotq_s32 (16 int8 MACs + * per instruction); otherwise a scalar integer fallback (auto-vectorized). + * The index mapping (groups, lo/hi sub-blocks, input alignment) is identical to + * the previous float kernel, which was parity-checked against Panama. */ SKAINET_API void skainet_q4k_matmul( const float* SKAINET_RESTRICT input, @@ -92,86 +121,102 @@ SKAINET_API void skainet_q4k_matmul( const float* in_base = input + input_offset; float* out_base = output + output_offset; + /* Pre-quantize the whole input row to Q8 once (reused across all o). 
*/ + int8_t* q8 = (int8_t*) malloc((size_t) input_dim * sizeof(int8_t)); + float* d_in = (float*) malloc((size_t) blocks_per_input_dim * sizeof(float)); + if (q8 == NULL || d_in == NULL) { free(q8); free(d_in); return; } + for (int32_t b = 0; b < blocks_per_input_dim; ++b) { + d_in[b] = skainet_q8_quantize_block(in_base + (size_t) b * Q4K_BLOCK_SIZE, + q8 + (size_t) b * Q4K_BLOCK_SIZE); + } + int scale_idx[Q4K_SUB_BLOCKS]; int min_idx[Q4K_SUB_BLOCKS]; - for (int32_t o = 0; o < output_dim; ++o) { - float acc = 0.0f; - - for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) { - const uint8_t* block = weight + weight_byte_offset - + (size_t)(block_idx * output_dim + o) * Q4K_BYTES_PER_BLOCK; - - /* d, dMin (FP16 LE -> FP32). */ + /* + * Loop order: block OUTER, output row INNER. The weight is packed + * block-major — (blockIdx * output_dim + o) * 144 — so for a fixed block, + * consecutive `o` are exactly 144 bytes apart: the weight bytes are read + * strictly sequentially (prefetch- and cache-line-friendly). The reverse + * order (o outer) strides output_dim*144 bytes per step (~295 KB on the + * down-proj), which on an in-order A55 with small caches makes every weight + * read a cold miss and dominates runtime regardless of inner-loop compute. + * out_base[o] is accumulated across blocks (output_dim*4 bytes stays hot in + * cache); the accumulation order over blocks is unchanged, so this is + * numerically identical to the o-outer form. 
+ */ + for (int32_t o = 0; o < output_dim; ++o) out_base[o] = 0.0f; + + for (int32_t block_idx = 0; block_idx < blocks_per_input_dim; ++block_idx) { + const int8_t* q8_block = q8 + (size_t) block_idx * Q4K_BLOCK_SIZE; + const float di = d_in[block_idx]; + const uint8_t* block = weight + weight_byte_offset + + (size_t)(block_idx * output_dim) * Q4K_BYTES_PER_BLOCK; + + for (int32_t o = 0; o < output_dim; ++o, block += Q4K_BYTES_PER_BLOCK) { const uint16_t d_bits = (uint16_t) block[0] | ((uint16_t) block[1] << 8); const uint16_t d_min_bits = (uint16_t) block[2] | ((uint16_t) block[3] << 8); const float d = skainet_half_to_float(d_bits); const float d_min = skainet_half_to_float(d_min_bits); - /* 12 bytes of packed (scaleIdx, minIdx) -> 8 ints each. */ skainet_q4k_decode_scales(block + 4, scale_idx, min_idx); const uint8_t* qs = block + 16; - const float* in_block = in_base + (size_t) block_idx * Q4K_BLOCK_SIZE; - /* 4 strided qs groups; group j carries sub-blocks 2j (lo) and 2j+1 (hi). */ + int64_t block_scale_dot = 0; + int64_t block_min_sum = 0; + for (int group_j = 0; group_j < 4; ++group_j) { - const uint8_t* qs_group = qs + group_j * Q4K_SUB_BLOCK_SIZE; + const uint8_t* qs_group = qs + group_j * Q4K_SUB_BLOCK_SIZE; const int sb_lo = 2 * group_j; const int sb_hi = sb_lo + 1; - const float* in_lo = in_block + sb_lo * Q4K_SUB_BLOCK_SIZE; - const float* in_hi = in_block + sb_hi * Q4K_SUB_BLOCK_SIZE; + const int8_t* q8_lo = q8_block + sb_lo * Q4K_SUB_BLOCK_SIZE; + const int8_t* q8_hi = q8_block + sb_hi * Q4K_SUB_BLOCK_SIZE; - float code_sum_lo = 0.0f, input_sum_lo = 0.0f; - float code_sum_hi = 0.0f, input_sum_hi = 0.0f; + int32_t dot_lo = 0, sum_lo = 0, dot_hi = 0, sum_hi = 0; -#ifdef SKAINET_HAVE_NEON - float32x4_t cacc_lo = vdupq_n_f32(0.0f), iacc_lo = vdupq_n_f32(0.0f); - float32x4_t cacc_hi = vdupq_n_f32(0.0f), iacc_hi = vdupq_n_f32(0.0f); +#ifdef SKAINET_HAVE_DOTPROD + int32x4_t acc_dot_lo = vdupq_n_s32(0), acc_dot_hi = vdupq_n_s32(0); + int32_t acc_sum_lo = 0, 
acc_sum_hi = 0; for (int off = 0; off < Q4K_SUB_BLOCK_SIZE; off += 16) { const uint8x16_t packed = vld1q_u8(qs_group + off); - const uint8x16_t lo_nib = vandq_u8(packed, vdupq_n_u8(0x0F)); - const uint8x16_t hi_nib = vshrq_n_u8(packed, 4); - float32x4_t cl[4], ch[4]; - skainet_neon_u8x16_to_f32x4x4(lo_nib, cl); - skainet_neon_u8x16_to_f32x4x4(hi_nib, ch); - for (int q = 0; q < 4; ++q) { - const float32x4_t v_lo = vld1q_f32(in_lo + off + q * 4); - const float32x4_t v_hi = vld1q_f32(in_hi + off + q * 4); - cacc_lo = vfmaq_f32(cacc_lo, v_lo, cl[q]); - iacc_lo = vaddq_f32(iacc_lo, v_lo); - cacc_hi = vfmaq_f32(cacc_hi, v_hi, ch[q]); - iacc_hi = vaddq_f32(iacc_hi, v_hi); - } + const int8x16_t code_lo = vreinterpretq_s8_u8(vandq_u8(packed, vdupq_n_u8(0x0F))); + const int8x16_t code_hi = vreinterpretq_s8_u8(vshrq_n_u8(packed, 4)); + const int8x16_t a_lo = vld1q_s8(q8_lo + off); + const int8x16_t a_hi = vld1q_s8(q8_hi + off); + acc_dot_lo = vdotq_s32(acc_dot_lo, code_lo, a_lo); + acc_dot_hi = vdotq_s32(acc_dot_hi, code_hi, a_hi); + acc_sum_lo += vaddlvq_s8(a_lo); + acc_sum_hi += vaddlvq_s8(a_hi); } - code_sum_lo = skainet_neon_hadd_f32(cacc_lo); - input_sum_lo = skainet_neon_hadd_f32(iacc_lo); - code_sum_hi = skainet_neon_hadd_f32(cacc_hi); - input_sum_hi = skainet_neon_hadd_f32(iacc_hi); + dot_lo = vaddvq_s32(acc_dot_lo); + dot_hi = vaddvq_s32(acc_dot_hi); + sum_lo = acc_sum_lo; + sum_hi = acc_sum_hi; #else - /* 32 iterations — auto-vectorizes cleanly under -O3. 
*/ for (int i = 0; i < Q4K_SUB_BLOCK_SIZE; ++i) { - const uint8_t b = qs_group[i]; - const float code_lo = (float)(b & 0x0F); - const float code_hi = (float)(b >> 4); - const float v_lo = in_lo[i]; - const float v_hi = in_hi[i]; - code_sum_lo += v_lo * code_lo; - input_sum_lo += v_lo; - code_sum_hi += v_hi * code_hi; - input_sum_hi += v_hi; + const uint8_t pb = qs_group[i]; + const int code_lo = (int)(pb & 0x0F); + const int code_hi = (int)(pb >> 4); + const int a_lo = (int) q8_lo[i]; + const int a_hi = (int) q8_hi[i]; + dot_lo += a_lo * code_lo; + sum_lo += a_lo; + dot_hi += a_hi * code_hi; + sum_hi += a_hi; } #endif - const float scale_lo = d * (float) scale_idx[sb_lo]; - const float offset_lo = d_min * (float) min_idx[sb_lo]; - const float scale_hi = d * (float) scale_idx[sb_hi]; - const float offset_hi = d_min * (float) min_idx[sb_hi]; - acc += code_sum_lo * scale_lo - input_sum_lo * offset_lo; - acc += code_sum_hi * scale_hi - input_sum_hi * offset_hi; + block_scale_dot += (int64_t) scale_idx[sb_lo] * dot_lo + + (int64_t) scale_idx[sb_hi] * dot_hi; + block_min_sum += (int64_t) min_idx[sb_lo] * sum_lo + + (int64_t) min_idx[sb_hi] * sum_hi; } - } - out_base[o] = acc; + out_base[o] += di * (d * (float) block_scale_dot - d_min * (float) block_min_sum); + } } + + free(q8); + free(d_in); } diff --git a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt index 8e1c9546..7cc278cf 100644 --- a/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt +++ b/skainet-backends/skainet-backend-native-cpu/src/jvmTest/kotlin/sk/ainet/exec/kernel/NativeQ4KMatmulKernelParityTest.kt @@ -72,14 +72,34 @@ class NativeQ4KMatmulKernelParityTest { nativeOut, 0, ) + // The native kernel quantizes the activation to int8 (Q8) for the dotprod + 
// fast path — deliberately lossy (ggml-style), so it is NOT bit-exact vs the + // float Panama reference. Per-row relative error is the WRONG gate here: + // on zero-mean uniform-random fixtures a row whose true value is ~0 shows an + // unbounded relative error from a tiny absolute one. The meaningful metric for + // a lossy kernel is the aggregate error energy — RMS(error) / RMS(signal) — + // which ggml-style validation uses. Real (smoother) LLM activations are far + // tighter than these worst-case fixtures; the end-to-end gate is the on-board + // generation output. + var sqErr = 0.0 + var sqSig = 0.0 for (o in 0 until outputDim) { - val diff = abs(refOut[o] - nativeOut[o]) - val rel = diff / (abs(refOut[o]) + 1e-9f) - assertTrue( - diff <= tol || rel < 1e-4f, - "row $o diverged: panama=${refOut[o]} native=${nativeOut[o]} diff=$diff rel=$rel tol=$tol", - ) + val d = (refOut[o] - nativeOut[o]).toDouble() + sqErr += d * d + sqSig += refOut[o].toDouble() * refOut[o].toDouble() } + val rmsErr = kotlin.math.sqrt(sqErr / outputDim) + val rmsSig = kotlin.math.sqrt(sqSig / outputDim) + val relRms = rmsErr / (rmsSig + 1e-9) + assertTrue( + relRms < AGG_REL_TOL || rmsErr < tol, + "Q8 parity exceeded: relRms=$relRms (rmsErr=$rmsErr rmsSig=$rmsSig) over $outputDim rows, tol=$AGG_REL_TOL", + ) + } + + private companion object { + // Aggregate Q8-activation RMS-relative-error bound (uniform-random worst case). 
+ const val AGG_REL_TOL = 0.03 } @Test From 0366365572e85afa5986a3d869ccf2dfdbca1483 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:27:21 +0200 Subject: [PATCH 03/18] data: add dataset operation views --- .../skainet-data-api/build.gradle.kts | 2 + .../kotlin/sk/ainet/data/Dataset.kt | 224 +++++++++++++++++- .../sk/ainet/data/DatasetAndDataBatchTest.kt | 131 +++++++++- 3 files changed, 347 insertions(+), 10 deletions(-) diff --git a/skainet-data/skainet-data-api/build.gradle.kts b/skainet-data/skainet-data-api/build.gradle.kts index 175b2f3f..511226f9 100644 --- a/skainet-data/skainet-data-api/build.gradle.kts +++ b/skainet-data/skainet-data-api/build.gradle.kts @@ -46,11 +46,13 @@ kotlin { val commonMain by getting { dependencies { implementation(project(":skainet-lang:skainet-lang-core")) + implementation(libs.kotlinx.coroutines) } } commonTest.dependencies { implementation(libs.kotlin.test) + implementation(libs.kotlinx.coroutines.test) // implementation(project(":skainet-core:skainet-performance")) } } diff --git a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt index abb111a8..cc72cae7 100644 --- a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt +++ b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/Dataset.kt @@ -1,17 +1,30 @@ package sk.ainet.data +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import sk.ainet.lang.tensor.Shape import sk.ainet.lang.types.DType import kotlin.math.min +import kotlin.random.Random /** Just abstract Dataset. */ public abstract class Dataset { + /** Optional shape metadata for one input sample. */ + public open val inputShape: Shape? get() = null + + /** Optional shape metadata for one output sample. */ + public open val outputShape: Shape? 
get() = null + /** Splits datasets on two sub-datasets according [splitRatio].*/ public abstract fun split(splitRatio: Double): Pair, Dataset> /** Returns amount of data rows. */ public abstract val xSize: Int + /** Returns amount of data rows. Alias for [xSize]. */ + public val size: Int get() = xSize + /** Returns row by index [idx]. */ public abstract fun getX(idx: Int): T @@ -21,6 +34,60 @@ public abstract class Dataset { /** Shuffles the dataset. */ public abstract fun shuffle(): Dataset + /** Shuffles the dataset, using [seed] when deterministic ordering is required. */ + public open fun shuffle(seed: Long? = null): Dataset { + if (seed == null) return shuffle() + val indices = IntArray(xSize) { it } + indices.shuffle(Random(seed)) + return IndexedDataset(this, indices) + } + + /** + * Splits the dataset with optional deterministic shuffling and label stratification. + * + * The original [split] remains the compatibility path. This overload adds the + * ML-oriented behavior expected by training pipelines without forcing every + * concrete dataset to duplicate the same index bookkeeping. + */ + public open fun split( + splitRatio: Double, + seed: Long? = null, + stratified: Boolean = false + ): Pair, Dataset> { + require(splitRatio > 0.0 && splitRatio < 1.0) { "splitRatio must be in (0,1)" } + if (seed == null && !stratified) return split(splitRatio) + + val (leftIndices, rightIndices) = if (stratified) { + stratifiedSplitIndices(splitRatio, seed) + } else { + val indices = IntArray(xSize) { it } + seed?.let { indices.shuffle(Random(it)) } + indices.splitAtRatio(splitRatio) + } + + return IndexedDataset(this, leftIndices) to IndexedDataset(this, rightIndices) + } + + /** Returns a dataset view containing only samples accepted by [predicate]. 
*/ + public fun filter(predicate: (T, Y) -> Boolean): Dataset { + val indices = (0 until xSize) + .filter { idx -> predicate(getX(idx), getY(idx)) } + .toIntArray() + return IndexedDataset(this, indices) + } + + /** Returns a dataset view with transformed input samples. */ + public fun mapX(transform: (T) -> NX): Dataset = + MappedDataset(this) { x, y -> transform(x) to y } + + /** Returns a dataset view with transformed target samples. */ + public fun mapY(transform: (Y) -> NY): Dataset = + MappedDataset(this) { x, y -> x to transform(y) } + + /** Returns a dataset view with transformed input and target samples. */ + public fun transform(transformer: (T, Y) -> Pair): Dataset = + MappedDataset(this, transformer) + /** * An iterator over a [Dataset]. */ @@ -43,9 +110,164 @@ public abstract class Dataset { /** Creates data batch that starts from [batchStart] with length [batchLength]. */ protected abstract fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch + /** + * Creates a data batch for arbitrary logical sample [indices]. + * + * Concrete datasets that can tensorize non-contiguous samples should override + * this method. The default path supports contiguous ranges and fails fast for + * non-contiguous index views instead of silently returning the wrong rows. + */ + protected open fun createIndexedDataBatch(indices: IntArray): DataBatch { + require(indices.isNotEmpty()) { "indices must not be empty" } + val first = indices.first() + require(first >= 0) { "indices must be non-negative" } + val contiguous = indices.withIndex().all { (offset, value) -> value == first + offset } + require(contiguous) { + "Non-contiguous data batches require createIndexedDataBatch(indices) support in the concrete dataset" + } + return createDataBatch(first, indices.size) + } + + /** Creates data batch that starts from [batchStart] with length [batchLength]. 
*/ + public fun dataBatch(batchStart: Int, batchLength: Int): DataBatch { + require(batchStart >= 0) { "batchStart must be non-negative" } + require(batchLength >= 0) { "batchLength must be non-negative" } + require(batchStart + batchLength <= xSize) { "batch exceeds dataset size" } + return createDataBatch(batchStart, batchLength) + } + + /** Creates a data batch for arbitrary logical sample [indices]. */ + public fun dataBatch(indices: IntArray): DataBatch { + require(indices.all { it in 0 until xSize }) { "indices must be inside dataset bounds" } + return createIndexedDataBatch(indices) + } /** Returns [BatchIterator] with fixed [batchSize]. */ public fun batchIterator(batchSize: Int): BatchIterator { + require(batchSize > 0) { "batchSize must be positive" } return BatchIterator(batchSize) } -} \ No newline at end of file + + /** Returns a cold [Flow] of data batches. */ + public fun batches( + batchSize: Int, + shuffle: Boolean = true, + seed: Long? = null + ): Flow> = flow { + val source = if (shuffle) shuffle(seed) else this@Dataset + val iterator = source.batchIterator(batchSize) + while (iterator.hasNext()) { + emit(iterator.next()) + } + } + + /** Returns a cold [Flow] over [epochCount] epochs of data batches. */ + public fun epochs( + epochCount: Int, + batchSize: Int, + shuffle: Boolean = true, + seed: Long? 
= null + ): Flow> = flow { + require(epochCount >= 0) { "epochCount must be non-negative" } + for (epoch in 0 until epochCount) { + val epochSeed = seed?.plus(epoch) + val iterator = (if (shuffle) shuffle(epochSeed) else this@Dataset).batchIterator(batchSize) + while (iterator.hasNext()) { + emit(iterator.next()) + } + } + } + + private fun stratifiedSplitIndices(splitRatio: Double, seed: Long?): Pair { + val buckets = LinkedHashMap>() + for (idx in 0 until xSize) { + buckets.getOrPut(getY(idx)) { mutableListOf() }.add(idx) + } + + val random = seed?.let { Random(it) } + val left = mutableListOf() + val right = mutableListOf() + + for (bucket in buckets.values) { + val indices = bucket.toMutableList() + if (random != null) indices.shuffle(random) + val splitIndex = (indices.size * splitRatio).toInt().coerceIn(0, indices.size) + left.addAll(indices.subList(0, splitIndex)) + right.addAll(indices.subList(splitIndex, indices.size)) + } + + if (random != null) { + left.shuffle(random) + right.shuffle(random) + } + + return left.toIntArray() to right.toIntArray() + } +} + +private class IndexedDataset( + private val source: Dataset, + private val indices: IntArray +) : Dataset() { + override val inputShape: Shape? get() = source.inputShape + override val outputShape: Shape? 
get() = source.outputShape + override val xSize: Int get() = indices.size + + override fun getX(idx: Int): X = source.getX(indices[idx]) + + override fun getY(idx: Int): Y = source.getY(indices[idx]) + + override fun shuffle(): Dataset { + val shuffled = indices.copyOf() + shuffled.shuffle(Random.Default) + return IndexedDataset(source, shuffled) + } + + override fun split(splitRatio: Double): Pair, Dataset> { + require(splitRatio > 0.0 && splitRatio < 1.0) { "splitRatio must be in (0,1)" } + val (left, right) = indices.splitAtRatio(splitRatio) + return IndexedDataset(source, left) to IndexedDataset(source, right) + } + + override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { + val actualLength = min(batchLength, xSize - batchStart) + val batchIndices = IntArray(actualLength) { offset -> indices[batchStart + offset] } + return source.dataBatch(batchIndices) + } + + override fun createIndexedDataBatch(indices: IntArray): DataBatch { + val sourceIndices = IntArray(indices.size) { offset -> this.indices[indices[offset]] } + return source.dataBatch(sourceIndices) + } +} + +private class MappedDataset( + private val source: Dataset, + private val transformer: (SX, SY) -> Pair +) : Dataset() { + override val xSize: Int get() = source.xSize + + override fun getX(idx: Int): TX { + val (x, _) = transformer(source.getX(idx), source.getY(idx)) + return x + } + + override fun getY(idx: Int): TY { + val (_, y) = transformer(source.getX(idx), source.getY(idx)) + return y + } + + override fun shuffle(): Dataset = shuffle(Random.nextLong()) + + override fun split(splitRatio: Double): Pair, Dataset> = + split(splitRatio, seed = null, stratified = false) + + override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { + throw UnsupportedOperationException("MappedDataset cannot create tensor batches without a tensorization transform") + } +} + +private fun IntArray.splitAtRatio(splitRatio: Double): Pair { + val splitIndex = (size * 
splitRatio).toInt().coerceIn(0, size) + return copyOfRange(0, splitIndex) to copyOfRange(splitIndex, size) +} diff --git a/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt b/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt index e8395b94..6643fdfd 100644 --- a/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt +++ b/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt @@ -1,5 +1,7 @@ package sk.ainet.data +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.test.assertEquals import sk.ainet.context.DefaultDataExecutionContext @@ -43,15 +45,25 @@ private class FakeDataset( @Suppress("UNCHECKED_CAST") override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { val end = (batchStart + batchLength).coerceAtMost(features.size) - val sliceF = features.subList(batchStart, end) - val sliceY = labels.subList(batchStart, end) - val featureSize = if (sliceF.isNotEmpty()) sliceF[0].size else 0 + val indices = (batchStart until end).toList() + return createBatchForIndices(indices) as DataBatch + } + + @Suppress("UNCHECKED_CAST") + override fun createIndexedDataBatch(indices: IntArray): DataBatch { + return createBatchForIndices(indices.toList()) as DataBatch + } + + private fun createBatchForIndices(indices: List): DataBatch { + val sliceF = indices.map { features[it] } + val sliceY = indices.map { labels[it] } + val featureSize = sliceF.firstOrNull()?.size ?: 0 // Build a single input tensor [batch, feature] val xTensor: Tensor = tensor(ctx, FP32::class) { tensor { - shape(end - batchStart, featureSize) { - val flat = FloatArray((end - batchStart) * featureSize) + shape(indices.size, featureSize) { + val flat = FloatArray(indices.size * featureSize) var k = 0 for (i in sliceF.indices) { val row = sliceF[i] @@ -67,15 +79,14 @@ 
private class FakeDataset( // Build label tensor [batch] val yTensor: Tensor = tensor(ctx, FP32::class) { tensor { - shape(end - batchStart) { - val flat = FloatArray(end - batchStart) { idx -> sliceY[idx] } + shape(indices.size) { + val flat = FloatArray(indices.size) { idx -> sliceY[idx] } fromArray(flat) } } } - val batch = DataBatch(arrayOf(xTensor), yTensor) - return batch as DataBatch + return DataBatch(arrayOf(xTensor), yTensor) } } @@ -168,4 +179,106 @@ class DatasetAndDataBatchTest { val b = (0 until shuffled.xSize).map { shuffled.getX(it).joinToString(",") }.sorted() assertEquals(a, b) } + + @Test + fun sizeAliasMatchesXSize() { + val ds = FakeDataset( + features = (1..3).map { i -> floatArrayOf(i.toFloat()) }, + labels = listOf(0f, 1f, 0f) + ) + + assertEquals(ds.xSize, ds.size) + } + + @Test + fun seededShuffleIsDeterministic() { + val ds = FakeDataset( + features = (1..12).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..12).map { (it % 2).toFloat() } + ) + + val first = ds.shuffle(seed = 42) + val second = ds.shuffle(seed = 42) + val third = ds.shuffle(seed = 99) + + val firstOrder = (0 until first.xSize).map { first.getX(it)[0] } + val secondOrder = (0 until second.xSize).map { second.getX(it)[0] } + val thirdOrder = (0 until third.xSize).map { third.getX(it)[0] } + + assertEquals(firstOrder, secondOrder) + assertNotEquals(firstOrder, thirdOrder) + } + + @Test + fun seededSplitIsDeterministic() { + val ds = FakeDataset( + features = (1..10).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..10).map { (it % 2).toFloat() } + ) + + val (trainA, testA) = ds.split(splitRatio = 0.7, seed = 123) + val (trainB, testB) = ds.split(splitRatio = 0.7, seed = 123) + + assertEquals((0 until trainA.xSize).map { trainA.getX(it)[0] }, (0 until trainB.xSize).map { trainB.getX(it)[0] }) + assertEquals((0 until testA.xSize).map { testA.getX(it)[0] }, (0 until testB.xSize).map { testB.getX(it)[0] }) + } + + @Test + fun stratifiedSplitKeepsBothLabelsInBothSides() 
{ + val ds = FakeDataset( + features = (1..20).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..20).map { (it % 2).toFloat() } + ) + + val (train, test) = ds.split(splitRatio = 0.5, seed = 7, stratified = true) + + assertEquals(setOf(0f, 1f), (0 until train.xSize).map { train.getY(it) }.toSet()) + assertEquals(setOf(0f, 1f), (0 until test.xSize).map { test.getY(it) }.toSet()) + } + + @Test + fun filterAndMapCreateDatasetViews() { + val ds = FakeDataset( + features = (1..6).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..6).map { (it % 2).toFloat() } + ) + + val filtered = ds.filter { _, label -> label == 0f } + val mapped = filtered + .mapX { features -> features[0].toInt() } + .mapY { label -> "class-${label.toInt()}" } + + assertEquals(3, mapped.xSize) + assertEquals(listOf(2, 4, 6), (0 until mapped.xSize).map { mapped.getX(it) }) + assertEquals(listOf("class-0", "class-0", "class-0"), (0 until mapped.xSize).map { mapped.getY(it) }) + } + + @Test + fun batchesFlowEmitsAllBatches() = runTest { + val ds = FakeDataset( + features = (1..5).map { i -> floatArrayOf(i.toFloat(), (i * 10).toFloat()) }, + labels = (1..5).map { (it % 2).toFloat() } + ) + + val batches = ds.batches(batchSize = 2, shuffle = false).toList() + + assertEquals(3, batches.size) + assertEquals(2, batches[0].y.shape[0]) + assertEquals(2, batches[1].y.shape[0]) + assertEquals(1, batches[2].y.shape[0]) + } + + @Test + fun epochsFlowUsesSeededShufflePerEpoch() = runTest { + val ds = FakeDataset( + features = (1..6).map { i -> floatArrayOf(i.toFloat()) }, + labels = (1..6).map { (it % 2).toFloat() } + ) + + val batches = ds.epochs(epochCount = 2, batchSize = 3, shuffle = true, seed = 11).toList() + + assertEquals(4, batches.size) + assertEquals(3, batches[0].y.shape[0]) + assertEquals(3, batches[3].y.shape[0]) + } } From 4d4cb01a9da1904e94371a8df4f6778c03caef5b Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:28:57 +0200 Subject: [PATCH 04/18] data: enrich data batch 
metadata --- .../kotlin/sk/ainet/data/DataBatch.kt | 43 +++++++++++++++++- .../sk/ainet/data/DatasetAndDataBatchTest.kt | 44 +++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt index 4183ff89..51617053 100644 --- a/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt +++ b/skainet-data/skainet-data-api/src/commonMain/kotlin/sk/ainet/data/DataBatch.kt @@ -1,10 +1,37 @@ package sk.ainet.data import sk.ainet.lang.tensor.Tensor +import sk.ainet.lang.tensor.sliceView import sk.ainet.lang.types.DType -public data class DataBatch(val x: Array>, val y: Tensor) { +public data class DataBatch( + val x: Array>, + val y: Tensor, + val indices: IntArray = IntArray(y.shape[0]) { it }, + val metadata: Map = emptyMap() +) { + /** Number of samples represented by this batch. */ + public val batchSize: Int get() = indices.size + + /** Returns a copy with [metadata] merged into the existing metadata map. */ + public fun withMetadata(metadata: Map): DataBatch = + copy(metadata = this.metadata + metadata) + + /** Returns a contiguous slice of this batch over its leading batch dimension. 
*/ + public fun slice(range: IntRange): DataBatch { + require(!range.isEmpty()) { "range must not be empty" } + require(range.first >= 0) { "range start must be non-negative" } + require(range.last < batchSize) { "range end must be within batch bounds" } + + val endExclusive = range.last + 1 + return copy( + x = x.map { tensor -> tensor.sliceLeadingDimension(range.first, endExclusive) }.toTypedArray(), + y = y.sliceLeadingDimension(range.first, endExclusive), + indices = indices.sliceArray(range) + ) + } + override fun equals(other: Any?): Boolean { if (this === other) return true if (other == null || this::class != other::class) return false @@ -13,6 +40,8 @@ public data class DataBatch(val x: Array>, val y: Ten if (!x.contentEquals(other.x)) return false if (y != other.y) return false + if (!indices.contentEquals(other.indices)) return false + if (metadata != other.metadata) return false return true } @@ -20,6 +49,18 @@ public data class DataBatch(val x: Array>, val y: Ten override fun hashCode(): Int { var result = x.contentHashCode() result = 31 * result + y.hashCode() + result = 31 * result + indices.contentHashCode() + result = 31 * result + metadata.hashCode() return result } } + +private fun Tensor.sliceLeadingDimension(start: Int, endExclusive: Int): Tensor { + if (rank == 0) return this + return sliceView { + segment { range(start, endExclusive) } + repeat(rank - 1) { + segment { all() } + } + } +} diff --git a/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt b/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt index 6643fdfd..daec89b1 100644 --- a/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt +++ b/skainet-data/skainet-data-api/src/commonTest/kotlin/sk/ainet/data/DatasetAndDataBatchTest.kt @@ -111,6 +111,50 @@ class DatasetAndDataBatchTest { assertNotEquals(batchA, batchC, "Batches with same values but different tensor 
instances should not be equal") } + @Test + fun dataBatchCarriesIndicesAndMetadata() { + val ctx = DefaultDataExecutionContext() + val x: Tensor = data(ctx) { tensor { shape(2, 2) { from(1f, 2f, 3f, 4f) } } } + val y: Tensor = tensor(ctx, FP32::class) { tensor { shape(2) { from(0f, 1f) } } } + + val batch = DataBatch( + x = arrayOf(x), + y = y, + indices = intArrayOf(10, 20), + metadata = mapOf("split" to "train") + ) + val enriched = batch.withMetadata(mapOf("epoch" to "1")) + + assertEquals(2, batch.batchSize) + assertEquals(mapOf("split" to "train", "epoch" to "1"), enriched.metadata) + assertNotEquals(batch, batch.copy(indices = intArrayOf(11, 20))) + } + + @Test + fun dataBatchSliceUsesLeadingDimension() { + val ctx = DefaultDataExecutionContext() + val x: Tensor = data(ctx) { + tensor { + shape(3, 2) { + from(1f, 2f, 3f, 4f, 5f, 6f) + } + } + } + val y: Tensor = tensor(ctx, FP32::class) { tensor { shape(3) { from(0f, 1f, 2f) } } } + val batch = DataBatch(arrayOf(x), y, indices = intArrayOf(4, 5, 6)) + + val sliced = batch.slice(1..2) + + assertEquals(2, sliced.batchSize) + assertEquals(listOf(5, 6), sliced.indices.toList()) + assertEquals(2, sliced.x[0].shape[0]) + assertEquals(2, sliced.y.shape[0]) + assertEquals(3f, sliced.x[0].data[0, 0]) + assertEquals(6f, sliced.x[0].data[1, 1]) + assertEquals(1f, sliced.y.data[0]) + assertEquals(2f, sliced.y.data[1]) + } + @Test fun batchIteratorProducesCorrectSlices() { val feats = listOf( From a8642b9b6ef2253236d0b3b591b2094f60d812c5 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:36:55 +0200 Subject: [PATCH 05/18] data: add URI source contracts --- settings.gradle.kts | 1 + .../skainet-data-source/build.gradle.kts | 38 ++++ .../skainet-data-source/gradle.properties | 2 + .../sk/ainet/data/source/DataSourceModels.kt | 102 +++++++++ .../ainet/data/source/DataSourceUriParser.kt | 200 ++++++++++++++++++ .../data/source/DataSourceUriParserTest.kt | 80 +++++++ 6 files changed, 423 insertions(+) 
create mode 100644 skainet-data/skainet-data-source/build.gradle.kts create mode 100644 skainet-data/skainet-data-source/gradle.properties create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt create mode 100644 skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt diff --git a/settings.gradle.kts b/settings.gradle.kts index bbfbc825..5e25ef39 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -48,6 +48,7 @@ include("skainet-backends:benchmarks:jvm-cpu-publish") // ====== DATA include("skainet-data:skainet-data-api") +include("skainet-data:skainet-data-source") include("skainet-data:skainet-data-transform") include("skainet-data:skainet-data-simple") include("skainet-data:skainet-data-media") diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts new file mode 100644 index 00000000..f8b4dbc0 --- /dev/null +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -0,0 +1,38 @@ +import org.jetbrains.kotlin.gradle.dsl.JvmTarget + +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.vanniktech.mavenPublish) + id("sk.ainet.dokka") +} + +kotlin { + explicitApi() + + jvm { + compilerOptions { + jvmTarget.set(JvmTarget.JVM_11) + } + } + + sourceSets { + commonMain.dependencies { + implementation(libs.kotlinx.coroutines) + } + + commonTest.dependencies { + implementation(libs.kotlin.test) + } + + jvmMain.dependencies { + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.core) + implementation(libs.ktor.client.plugins) + implementation(libs.kotlinx.coroutines.core.jvm) + } + + jvmTest.dependencies { + implementation(libs.kotlinx.coroutines.test) + } + } +} diff --git a/skainet-data/skainet-data-source/gradle.properties 
b/skainet-data/skainet-data-source/gradle.properties new file mode 100644 index 00000000..3516f9dd --- /dev/null +++ b/skainet-data/skainet-data-source/gradle.properties @@ -0,0 +1,2 @@ +POM_ARTIFACT_ID=skainet-data-source +POM_NAME=skainet data source diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt new file mode 100644 index 00000000..71eacf32 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt @@ -0,0 +1,102 @@ +package sk.ainet.data.source + +/** + * Cache behavior requested by a caller resolving a data artifact. + */ +public enum class CachePolicy { + /** Use a cached artifact when present, otherwise fetch and cache it. */ + Use, + + /** Fetch the artifact again and replace any existing cached copy. */ + Refresh, + + /** Require an existing cached or local artifact; do not use the network. */ + Offline, + + /** Fetch or read the artifact without writing a persistent cache entry. */ + Bypass +} + +/** + * High-level provider implied by a source URI. + */ +public enum class DataSourceProvider { + File, + Http, + HuggingFace +} + +/** + * Hugging Face repository namespace encoded by an `hf://` URI. + */ +public enum class HuggingFaceRepoType { + Model, + Dataset, + Space +} + +/** + * Parsed Hugging Face location, when a URI uses SKaiNET's `hf://` shorthand + * or the explicit `hf+https://...` provider prefix. + */ +public data class HuggingFaceLocation( + public val repoType: HuggingFaceRepoType, + public val repoId: String?, + public val revision: String?, + public val path: String? +) + +/** + * A normalized, provider-aware source URI. 
+ */ +public data class ParsedDataSourceUri( + public val rawUri: String, + public val provider: DataSourceProvider, + public val transportUri: String, + public val filename: String, + public val cacheKey: String, + public val localPath: String? = null, + public val huggingFace: HuggingFaceLocation? = null +) + +/** + * Request to resolve a local or remote artifact. + */ +public data class DataSourceRequest( + public val uri: String, + public val cachePolicy: CachePolicy = CachePolicy.Use, + public val expectedSha256: String? = null, + public val headers: Map = emptyMap() +) + +/** + * A resolved artifact. Remote artifacts may expose a [localPath] when they + * have been materialized into a platform cache. + */ +public class DataSourceArtifact( + public val request: DataSourceRequest, + public val parsedUri: ParsedDataSourceUri, + public val filename: String, + public val localPath: String?, + public val sizeBytes: Long?, + public val cacheHit: Boolean, + private val byteReader: suspend () -> ByteArray +) { + public suspend fun readBytes(): ByteArray = byteReader() +} + +/** + * Resolves source URIs into readable data artifacts. + */ +public interface DataSourceResolver { + public suspend fun resolve(request: DataSourceRequest): DataSourceArtifact +} + +public open class DataSourceException( + message: String, + cause: Throwable? = null +) : RuntimeException(message, cause) + +public class UnsupportedDataSourceUriException( + uri: String +) : DataSourceException("Unsupported data source URI: $uri") diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt new file mode 100644 index 00000000..06378277 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt @@ -0,0 +1,200 @@ +package sk.ainet.data.source + +/** + * Parses SKaiNET data source URIs. 
+ * + * Supported forms: + * - `file:///absolute/path` + * - `/absolute/or/relative/path` + * - `https://host/path` + * - `hf+https://huggingface.co/org/repo/resolve/main/file` + * - `hf://org/repo@revision/path/to/file` + * - `hf://datasets/org/repo@revision/path/to/file` + */ +public object DataSourceUriParser { + public fun parse(uri: String): ParsedDataSourceUri { + val raw = uri.trim() + require(raw.isNotEmpty()) { "Data source URI must not be blank" } + + return when { + raw.startsWith(HF_HTTPS_PREFIX) -> parseHfHttps(raw) + raw.startsWith(HF_URI_PREFIX) -> parseHfUri(raw) + raw.startsWith(FILE_URI_PREFIX) -> parseFileUri(raw) + raw.startsWith(HTTPS_PREFIX) || raw.startsWith(HTTP_PREFIX) -> parseHttp(raw) + raw.contains("://") -> throw UnsupportedDataSourceUriException(raw) + else -> parsePlainFilePath(raw) + } + } + + private fun parseFileUri(raw: String): ParsedDataSourceUri { + val localPath = normalizeFileUriPath(raw.removePrefix(FILE_URI_PREFIX)) + val filename = extractFilename(localPath) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.File, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.File, localPath, filename), + localPath = localPath + ) + } + + private fun parsePlainFilePath(raw: String): ParsedDataSourceUri { + val filename = extractFilename(raw) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.File, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.File, raw, filename), + localPath = raw + ) + } + + private fun parseHttp(raw: String): ParsedDataSourceUri { + val filename = extractFilename(raw) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.Http, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.Http, raw, filename) + ) + } + + private fun parseHfHttps(raw: String): ParsedDataSourceUri { + val transportUri = raw.removePrefix("hf+") + val filename = 
extractFilename(transportUri) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.HuggingFace, + transportUri = transportUri, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.HuggingFace, transportUri, filename), + huggingFace = HuggingFaceLocation( + repoType = HuggingFaceRepoType.Model, + repoId = null, + revision = null, + path = null + ) + ) + } + + private fun parseHfUri(raw: String): ParsedDataSourceUri { + val body = raw.removePrefix(HF_URI_PREFIX).trim('/') + val segments = body.split('/').filter { it.isNotBlank() } + require(segments.size >= 3) { + "hf:// URI must include repo owner, repo name, and file path: $raw" + } + + val (repoType, repoStart) = when (segments.first()) { + "models", "model" -> HuggingFaceRepoType.Model to 1 + "datasets", "dataset" -> HuggingFaceRepoType.Dataset to 1 + "spaces", "space" -> HuggingFaceRepoType.Space to 1 + else -> HuggingFaceRepoType.Model to 0 + } + require(segments.size - repoStart >= 3) { + "hf:// URI must include repo owner, repo name, and file path: $raw" + } + + val owner = segments[repoStart] + val repoAndRevision = segments[repoStart + 1] + val repoName = repoAndRevision.substringBefore('@') + val revision = repoAndRevision.substringAfter('@', "main") + val filePath = segments.drop(repoStart + 2).joinToString("/") + val repoId = "$owner/$repoName" + val prefix = when (repoType) { + HuggingFaceRepoType.Model -> "" + HuggingFaceRepoType.Dataset -> "datasets/" + HuggingFaceRepoType.Space -> "spaces/" + } + val transportUri = "https://huggingface.co/$prefix$repoId/resolve/$revision/$filePath" + val filename = extractFilename(filePath) + + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.HuggingFace, + transportUri = transportUri, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.HuggingFace, transportUri, filename), + huggingFace = HuggingFaceLocation( + repoType = repoType, + repoId = repoId, + revision = revision, + path = filePath + 
) + ) + } + + private fun normalizeFileUriPath(path: String): String { + val withoutLocalhost = path.removePrefix("localhost/") + val normalized = if (withoutLocalhost.startsWith("/")) withoutLocalhost else "/$withoutLocalhost" + return percentDecode(normalized) + } + + private fun extractFilename(value: String): String { + val withoutFragment = value.substringBefore('#').substringBefore('?').trimEnd('/') + val filename = withoutFragment.substringAfterLast('/', missingDelimiterValue = withoutFragment) + return percentDecode(filename).ifBlank { "artifact" } + } + + private fun percentDecode(value: String): String { + val out = StringBuilder(value.length) + var i = 0 + while (i < value.length) { + val c = value[i] + if (c == '%' && i + 2 < value.length) { + val decoded = hexByte(value[i + 1], value[i + 2]) + if (decoded != null) { + out.append(decoded.toInt().toChar()) + i += 3 + continue + } + } + out.append(c) + i++ + } + return out.toString() + } + + private fun hexByte(high: Char, low: Char): Byte? { + val hi = high.digitToIntOrNull(16) ?: return null + val lo = low.digitToIntOrNull(16) ?: return null + return ((hi shl 4) or lo).toByte() + } + + private fun cacheKey(provider: DataSourceProvider, normalizedUri: String, filename: String): String { + val safeName = filename.map { ch -> + if (ch.isLetterOrDigit() || ch == '.' 
|| ch == '-' || ch == '_') ch else '_' + }.joinToString("") + return "${provider.name.lowercase()}-${fnv1a32Hex(normalizedUri)}-$safeName" + } + + private fun fnv1a32Hex(value: String): String { + var hash = FNV_OFFSET + val bytes = value.encodeToByteArray() + for (byte in bytes) { + hash = hash xor (byte.toInt() and 0xff) + hash *= FNV_PRIME + } + return hash.toHex8() + } + + private fun Int.toHex8(): String { + val chars = CharArray(8) + for (i in chars.indices) { + val shift = (7 - i) * 4 + chars[i] = HEX[(this ushr shift) and 0x0f] + } + return chars.concatToString() + } + + private const val FILE_URI_PREFIX = "file://" + private const val HTTP_PREFIX = "http://" + private const val HTTPS_PREFIX = "https://" + private const val HF_HTTPS_PREFIX = "hf+https://" + private const val HF_URI_PREFIX = "hf://" + private const val FNV_OFFSET = -2128831035 + private const val FNV_PRIME = 16777619 + private val HEX: CharArray = "0123456789abcdef".toCharArray() +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt new file mode 100644 index 00000000..9076c5be --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt @@ -0,0 +1,80 @@ +package sk.ainet.data.source + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertNull + +class DataSourceUriParserTest { + @Test + fun parsesFileUri() { + val parsed = DataSourceUriParser.parse("file:///tmp/skainet/train-images.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("/tmp/skainet/train-images.idx", parsed.localPath) + assertEquals("train-images.idx", parsed.filename) + } + + @Test + fun parsesPlainPathAsFile() { + val parsed = 
DataSourceUriParser.parse("fixtures/mnist/train-labels.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("fixtures/mnist/train-labels.idx", parsed.localPath) + assertEquals("train-labels.idx", parsed.filename) + } + + @Test + fun parsesHttpUri() { + val parsed = DataSourceUriParser.parse("https://example.test/data/sample.csv?download=1") + + assertEquals(DataSourceProvider.Http, parsed.provider) + assertEquals("sample.csv", parsed.filename) + assertNull(parsed.huggingFace) + } + + @Test + fun parsesHuggingFaceHttpsProviderPrefix() { + val parsed = DataSourceUriParser.parse( + "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" + ) + + assertEquals(DataSourceProvider.HuggingFace, parsed.provider) + assertEquals( + "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json", + parsed.transportUri + ) + assertEquals("tokenizer.json", parsed.filename) + } + + @Test + fun parsesHuggingFaceDatasetShorthand() { + val parsed = DataSourceUriParser.parse("hf://datasets/mnist/mnist@main/plain_text/train-00000.parquet") + + assertEquals(DataSourceProvider.HuggingFace, parsed.provider) + assertEquals(HuggingFaceRepoType.Dataset, parsed.huggingFace?.repoType) + assertEquals("mnist/mnist", parsed.huggingFace?.repoId) + assertEquals("main", parsed.huggingFace?.revision) + assertEquals("plain_text/train-00000.parquet", parsed.huggingFace?.path) + assertEquals( + "https://huggingface.co/datasets/mnist/mnist/resolve/main/plain_text/train-00000.parquet", + parsed.transportUri + ) + } + + @Test + fun cacheKeyDependsOnNormalizedUri() { + val first = DataSourceUriParser.parse("https://example.test/a.txt") + val second = DataSourceUriParser.parse("https://example.test/b.txt") + + assertNotEquals(first.cacheKey, second.cacheKey) + } + + @Test + fun rejectsUnknownSchemes() { + assertFailsWith { + DataSourceUriParser.parse("s3://bucket/object") + } + } +} From 18fdae77266bf7b6d9cb7c5f95edb5a3621d0c2a Mon Sep 17 
00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:38:41 +0200 Subject: [PATCH 06/18] data: materialize JVM source artifacts --- .../ainet/data/source/DataSourceUriParser.kt | 15 +- .../data/source/DataSourceUriParserTest.kt | 9 + .../data/source/JvmDataSourceResolver.kt | 169 ++++++++++++++++++ .../data/source/JvmDataSourceResolverTest.kt | 162 +++++++++++++++++ 4 files changed, 350 insertions(+), 5 deletions(-) create mode 100644 skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt create mode 100644 skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt index 06378277..9281358a 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt @@ -19,7 +19,7 @@ public object DataSourceUriParser { return when { raw.startsWith(HF_HTTPS_PREFIX) -> parseHfHttps(raw) raw.startsWith(HF_URI_PREFIX) -> parseHfUri(raw) - raw.startsWith(FILE_URI_PREFIX) -> parseFileUri(raw) + raw.startsWith(FILE_URI_SCHEME) -> parseFileUri(raw) raw.startsWith(HTTPS_PREFIX) || raw.startsWith(HTTP_PREFIX) -> parseHttp(raw) raw.contains("://") -> throw UnsupportedDataSourceUriException(raw) else -> parsePlainFilePath(raw) @@ -27,7 +27,7 @@ public object DataSourceUriParser { } private fun parseFileUri(raw: String): ParsedDataSourceUri { - val localPath = normalizeFileUriPath(raw.removePrefix(FILE_URI_PREFIX)) + val localPath = normalizeFileUriPath(raw.removePrefix(FILE_URI_SCHEME)) val filename = extractFilename(localPath) return ParsedDataSourceUri( rawUri = raw, @@ -127,8 +127,13 @@ public object DataSourceUriParser { } private fun 
normalizeFileUriPath(path: String): String { - val withoutLocalhost = path.removePrefix("localhost/") - val normalized = if (withoutLocalhost.startsWith("/")) withoutLocalhost else "/$withoutLocalhost" + val normalized = when { + path.startsWith("//localhost/") -> path.removePrefix("//localhost") + path.startsWith("///") -> path.drop(2) + path.startsWith("//") -> path.drop(1) + path.startsWith("/") -> path + else -> "/$path" + } return percentDecode(normalized) } @@ -189,7 +194,7 @@ public object DataSourceUriParser { return chars.concatToString() } - private const val FILE_URI_PREFIX = "file://" + private const val FILE_URI_SCHEME = "file:" private const val HTTP_PREFIX = "http://" private const val HTTPS_PREFIX = "https://" private const val HF_HTTPS_PREFIX = "hf+https://" diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt index 9076c5be..7bcef66f 100644 --- a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt @@ -16,6 +16,15 @@ class DataSourceUriParserTest { assertEquals("train-images.idx", parsed.filename) } + @Test + fun parsesJvmFileUri() { + val parsed = DataSourceUriParser.parse("file:/tmp/skainet/train-images.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("/tmp/skainet/train-images.idx", parsed.localPath) + assertEquals("train-images.idx", parsed.filename) + } + @Test fun parsesPlainPathAsFile() { val parsed = DataSourceUriParser.parse("fixtures/mnist/train-labels.idx") diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt new file mode 100644 
index 00000000..7cba502a --- /dev/null +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -0,0 +1,169 @@ +package sk.ainet.data.source + +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.request.get +import io.ktor.client.request.header +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import java.io.File +import java.security.MessageDigest + +/** + * Fetches a remote URI into memory. Kept injectable so tests and applications + * can provide their own HTTP stack or policy layer. + */ +public fun interface RemoteDataSourceFetcher { + public suspend fun fetch(uri: String, headers: Map): ByteArray +} + +/** + * Ktor/CIO-backed remote fetcher for JVM data artifacts. + */ +public class KtorRemoteDataSourceFetcher( + private val client: HttpClient = HttpClient(CIO) { + expectSuccess = true + install(HttpTimeout) { + requestTimeoutMillis = 600_000 + connectTimeoutMillis = 60_000 + socketTimeoutMillis = 600_000 + } + } +) : RemoteDataSourceFetcher, AutoCloseable { + override suspend fun fetch(uri: String, headers: Map): ByteArray { + return client.get(uri) { + headers.forEach { (name, value) -> header(name, value) } + }.body() + } + + override fun close() { + client.close() + } +} + +/** + * JVM resolver for local files and cached remote artifacts. 
+ */ +public class JvmDataSourceResolver( + private val cacheDir: File = defaultCacheDir(), + private val fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() +) : DataSourceResolver { + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { + val parsed = DataSourceUriParser.parse(request.uri) + when (parsed.provider) { + DataSourceProvider.File -> resolveFile(request, parsed) + DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed) + } + } + + private fun resolveFile( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") + val file = File(path) + require(file.exists()) { "Data source file not found: ${file.absolutePath}" } + require(file.isFile) { "Data source path is not a file: ${file.absolutePath}" } + request.expectedSha256?.let { verifySha256(file.readBytes(), it, request.uri) } + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = file.absolutePath, + sizeBytes = file.length(), + cacheHit = true, + byteReader = { file.readBytes() } + ) + } + + private suspend fun resolveRemote( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val target = File(cacheDir, parsed.cacheKey) + val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline + if (canUseCache && target.exists() && target.isFile) { + request.expectedSha256?.let { verifySha256(target.readBytes(), it, request.uri) } + return cachedArtifact(request, parsed, target, cacheHit = true) + } + + if (request.cachePolicy == CachePolicy.Offline) { + throw DataSourceException("No cached artifact available for offline source: ${request.uri}") + } + + val bytes = fetcher.fetch(parsed.transportUri, requestHeaders(request, parsed)) + 
request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } + + if (request.cachePolicy == CachePolicy.Bypass) { + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = null, + sizeBytes = bytes.size.toLong(), + cacheHit = false, + byteReader = { bytes } + ) + } + + cacheDir.mkdirs() + val temp = File(cacheDir, "${parsed.cacheKey}.tmp") + temp.writeBytes(bytes) + if (!temp.renameTo(target)) { + temp.copyTo(target, overwrite = true) + temp.delete() + } + return cachedArtifact(request, parsed, target, cacheHit = false) + } + + private fun cachedArtifact( + request: DataSourceRequest, + parsed: ParsedDataSourceUri, + target: File, + cacheHit: Boolean + ): DataSourceArtifact { + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = target.absolutePath, + sizeBytes = target.length(), + cacheHit = cacheHit, + byteReader = { target.readBytes() } + ) + } + + private fun requestHeaders( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): Map { + if (parsed.provider != DataSourceProvider.HuggingFace) return request.headers + if (request.headers.keys.any { it.equals("Authorization", ignoreCase = true) }) return request.headers + val token = System.getenv("HF_TOKEN") + ?.takeIf { it.isNotBlank() } + ?: System.getenv("HUGGING_FACE_HUB_TOKEN")?.takeIf { it.isNotBlank() } + ?: return request.headers + return request.headers + ("Authorization" to "Bearer $token") + } + + private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { + val actual = MessageDigest.getInstance("SHA-256") + .digest(bytes) + .joinToString("") { byte -> "%02x".format(byte) } + if (!actual.equals(expected, ignoreCase = true)) { + throw DataSourceException( + "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" + ) + } + } + + public companion object { + public fun defaultCacheDir(): File { + val userHome = 
System.getProperty("user.home")?.takeIf { it.isNotBlank() } + val base = userHome ?: System.getProperty("java.io.tmpdir") + return File(base, ".cache/skainet/data") + } + } +} diff --git a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt new file mode 100644 index 00000000..5b44523e --- /dev/null +++ b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt @@ -0,0 +1,162 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import java.nio.file.Files +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class JvmDataSourceResolverTest { + @Test + fun resolvesLocalFileUri() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val file = root.resolve("sample.txt") + file.writeText("hello") + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache")) + + val artifact = resolver.resolve(DataSourceRequest(file.toURI().toString())) + + assertEquals("sample.txt", artifact.filename) + assertEquals(file.absolutePath, artifact.localPath) + assertTrue(artifact.cacheHit) + assertContentEquals("hello".encodeToByteArray(), artifact.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun cachesRemoteArtifacts() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("first".encodeToByteArray()) + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = fetcher) + val request = DataSourceRequest( + "hf+https://huggingface.co/example/model/resolve/main/config.json" + ) + + val first = resolver.resolve(request) + val 
second = resolver.resolve(request) + + assertEquals(1, fetcher.calls) + assertFalse(first.cacheHit) + assertTrue(second.cacheHit) + assertNotNull(second.localPath) + assertContentEquals("first".encodeToByteArray(), second.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun refreshFetchesAgain() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = QueueFetcher( + "old".encodeToByteArray(), + "new".encodeToByteArray() + ) + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = fetcher) + val uri = "https://example.test/data.bin" + + resolver.resolve(DataSourceRequest(uri)).readBytes() + val refreshed = resolver.resolve(DataSourceRequest(uri, cachePolicy = CachePolicy.Refresh)) + + assertEquals(2, fetcher.calls) + assertContentEquals("new".encodeToByteArray(), refreshed.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun offlineFailsWhenCacheIsMissing() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = FakeFetcher(ByteArray(0))) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/missing.bin", + cachePolicy = CachePolicy.Offline + ) + ) + } + } finally { + root.deleteRecursively() + } + } + + @Test + fun bypassDoesNotWriteCache() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("bytes".encodeToByteArray()) + val cacheDir = root.resolve("cache") + val resolver = JvmDataSourceResolver(cacheDir = cacheDir, fetcher = fetcher) + + val artifact = resolver.resolve( + DataSourceRequest("https://example.test/data.bin", cachePolicy = CachePolicy.Bypass) + ) + + assertEquals(1, fetcher.calls) + assertEquals(null, artifact.localPath) + assertFalse(cacheDir.exists()) + } finally { + 
root.deleteRecursively() + } + } + + @Test + fun verifiesSha256() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = FakeFetcher("payload".encodeToByteArray()) + ) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/payload.bin", + expectedSha256 = "0000" + ) + ) + } + } finally { + root.deleteRecursively() + } + } +} + +private class FakeFetcher( + private val bytes: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + override suspend fun fetch(uri: String, headers: Map): ByteArray { + calls++ + return bytes + } +} + +private class QueueFetcher( + private vararg val responses: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + override suspend fun fetch(uri: String, headers: Map): ByteArray { + val index = calls.coerceAtMost(responses.lastIndex) + calls++ + return responses[index] + } +} From 27841eb764ecd3de2ff30ec794b59a05d50fc894 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:41:43 +0200 Subject: [PATCH 07/18] data: route simple loaders through sources --- .../skainet-data-simple/build.gradle.kts | 1 + .../sk/ainet/data/cifar10/CIFAR10Data.kt | 3 +- .../data/fashionmnist/FashionMNISTData.kt | 6 +- .../fashionmnist/FashionMNISTLoaderCommon.kt | 8 +- .../kotlin/sk/ainet/data/mnist/MNISTData.kt | 8 +- .../sk/ainet/data/mnist/MNISTLoaderCommon.kt | 8 +- .../sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt | 72 +++--------- .../fashionmnist/FashionMNISTLoaderJvm.kt | 103 ++++-------------- .../sk/ainet/data/mnist/MNISTLoaderJvm.kt | 103 ++++-------------- .../io/data/cifar10/CIFAR10LoaderTest.kt | 3 +- .../fashionmnist/FashionMNISTLoaderTest.kt | 4 +- .../sk/ainet/io/data/mnist/MNISTLoaderTest.kt | 28 ++++- 12 files changed, 110 insertions(+), 237 deletions(-) diff --git a/skainet-data/skainet-data-simple/build.gradle.kts 
b/skainet-data/skainet-data-simple/build.gradle.kts index 562bedd6..9bd671f0 100644 --- a/skainet-data/skainet-data-simple/build.gradle.kts +++ b/skainet-data/skainet-data-simple/build.gradle.kts @@ -62,6 +62,7 @@ kotlin { } jvmMain.dependencies { + implementation(project(":skainet-data:skainet-data-source")) implementation(libs.ktor.client.cio) implementation(libs.ktor.client.plugins) implementation(libs.ktor.client.logging) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt index f2a1e106..9344e150 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt @@ -144,7 +144,8 @@ public data class CIFAR10Dataset( */ public data class CIFAR10LoaderConfig( val cacheDir: String = "cifar10-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt index b9227418..bdc0e99e 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt @@ -146,7 +146,11 @@ public data class FashionMNISTDataset( */ public data class FashionMNISTLoaderConfig( val cacheDir: String = "fashion-mnist-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val trainImagesUri: String = FashionMNISTConstants.TRAIN_IMAGES_URL, + val trainLabelsUri: String = FashionMNISTConstants.TRAIN_LABELS_URL, + val testImagesUri: String = 
FashionMNISTConstants.TEST_IMAGES_URL, + val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt index 45888031..9fff81fd 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt @@ -16,11 +16,11 @@ public abstract class FashionMNISTLoaderCommon(public val config: FashionMNISTLo */ override suspend fun loadTrainingData(): FashionMNISTDataset { val imagesBytes = downloadAndCacheFile( - FashionMNISTConstants.TRAIN_IMAGES_URL, + config.trainImagesUri, FashionMNISTConstants.TRAIN_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - FashionMNISTConstants.TRAIN_LABELS_URL, + config.trainLabelsUri, FashionMNISTConstants.TRAIN_LABELS_FILENAME ) @@ -34,11 +34,11 @@ public abstract class FashionMNISTLoaderCommon(public val config: FashionMNISTLo */ override suspend fun loadTestData(): FashionMNISTDataset { val imagesBytes = downloadAndCacheFile( - FashionMNISTConstants.TEST_IMAGES_URL, + config.testImagesUri, FashionMNISTConstants.TEST_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - FashionMNISTConstants.TEST_LABELS_URL, + config.testLabelsUri, FashionMNISTConstants.TEST_LABELS_FILENAME ) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt index d120cdd6..c4bfa76a 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt @@ -124,7 +124,11 @@ public data class 
MNISTDataset( */ public data class MNISTLoaderConfig( val cacheDir: String = "mnist-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val trainImagesUri: String = MNISTConstants.TRAIN_IMAGES_URL, + val trainLabelsUri: String = MNISTConstants.TRAIN_LABELS_URL, + val testImagesUri: String = MNISTConstants.TEST_IMAGES_URL, + val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL ) /** @@ -164,4 +168,4 @@ public interface MNISTLoader { * @return The MNIST test dataset. */ public suspend fun loadTestData(): MNISTDataset -} \ No newline at end of file +} diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt index 4c334d31..66ef8085 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt @@ -14,11 +14,11 @@ public abstract class MNISTLoaderCommon(public val config: MNISTLoaderConfig) : */ override suspend fun loadTrainingData(): MNISTDataset { val imagesBytes = downloadAndCacheFile( - MNISTConstants.TRAIN_IMAGES_URL, + config.trainImagesUri, MNISTConstants.TRAIN_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - MNISTConstants.TRAIN_LABELS_URL, + config.trainLabelsUri, MNISTConstants.TRAIN_LABELS_FILENAME ) @@ -32,11 +32,11 @@ public abstract class MNISTLoaderCommon(public val config: MNISTLoaderConfig) : */ override suspend fun loadTestData(): MNISTDataset { val imagesBytes = downloadAndCacheFile( - MNISTConstants.TEST_IMAGES_URL, + config.testImagesUri, MNISTConstants.TEST_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - MNISTConstants.TEST_LABELS_URL, + config.testLabelsUri, MNISTConstants.TEST_LABELS_FILENAME ) diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt 
b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt index 82212e3f..c6ddcce5 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt @@ -1,18 +1,13 @@ package sk.ainet.data.cifar10 -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream import java.io.File -import java.io.FileInputStream import java.io.FileOutputStream -import java.io.ByteArrayInputStream import java.util.zip.GZIPInputStream /** @@ -23,6 +18,7 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the CIFAR-10 loader. */ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) { + private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) /** * Downloads the CIFAR-10 archive and extracts the specified batch file. 
@@ -45,21 +41,16 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon return@withContext batchFile.readBytes() } - // Check if we need to download and extract the archive + // Check if we need to resolve and extract the archive if (!extractedDir.exists() || !config.useCache) { - val archiveFile = File(cacheDir, CIFAR10Constants.ARCHIVE_FILENAME) - - // Download if not cached - if (!archiveFile.exists() || !config.useCache) { - println("Downloading CIFAR-10 archive: ${CIFAR10Constants.DOWNLOAD_URL}") - downloadFile(CIFAR10Constants.DOWNLOAD_URL, archiveFile.path) - } else { - println("Using cached archive: ${archiveFile.path}") - } - - // Extract the archive + val archive = resolver.resolve( + DataSourceRequest( + uri = config.archiveUri, + cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh + ) + ) println("Extracting CIFAR-10 archive...") - extractTarGz(archiveFile.path, cacheDir.path) + extractTarGz(archive.readBytes(), cacheDir.path) } if (!batchFile.exists()) { @@ -69,48 +60,17 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon return@withContext batchFile.readBytes() } - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. 
- */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files (CIFAR-10 is ~170MB) - install(HttpTimeout) { - requestTimeoutMillis = 600000 // 10 minutes - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 600000 // 10 minutes - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path} (${responseBody.size} bytes)") - } finally { - client.close() - } - } - /** * Extracts a .tar.gz archive using a simple TAR parser. * - * @param archivePath The path to the .tar.gz file. + * @param archiveBytes The bytes of the .tar.gz file. * @param outputDir The directory to extract files to. */ - private fun extractTarGz(archivePath: String, outputDir: String) { + private fun extractTarGz(archiveBytes: ByteArray, outputDir: String) { val outputDirFile = File(outputDir) // First, decompress gzip to get the tar content - val tarBytes = GZIPInputStream(FileInputStream(archivePath)).use { gzipIn -> + val tarBytes = GZIPInputStream(ByteArrayInputStream(archiveBytes)).use { gzipIn -> gzipIn.readBytes() } diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt index 2444a779..332dfeb2 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt @@ -1,18 +1,13 @@ package sk.ainet.data.fashionmnist -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import 
io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import java.io.File -import java.io.FileInputStream -import java.io.FileOutputStream +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream import java.util.zip.GZIPInputStream +import java.io.File /** * JVM implementation of the Fashion-MNIST loader. @@ -20,91 +15,32 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the Fashion-MNIST loader. */ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) { + private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) /** - * Downloads and caches a file. + * Resolves, caches, and decompresses a file when needed. * * @param url The URL to download from. * @param filename The name of the file to save. * @return The bytes of the decompressed file. 
*/ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val cacheDir = File(config.cacheDir) - if (!cacheDir.exists()) { - cacheDir.mkdirs() - } - - val gzipFile = File(cacheDir, filename) - val decompressedFile = File(cacheDir, filename.removeSuffix(".gz")) - - // Check if the decompressed file already exists in cache - if (config.useCache && decompressedFile.exists()) { - println("Using cached file: ${decompressedFile.path}") - return@withContext decompressedFile.readBytes() - } - - // Check if the gzip file already exists in cache - if (!gzipFile.exists() || !config.useCache) { - println("Downloading Fashion-MNIST file: $url") - downloadFile(url, gzipFile.path) - } else { - println("Using cached gzip file: ${gzipFile.path}") - } - - // Decompress the gzip file - println("Decompressing file: ${gzipFile.path}") - decompressGzipFile(gzipFile.path, decompressedFile.path) - - return@withContext decompressedFile.readBytes() + val artifact = resolver.resolve( + DataSourceRequest( + uri = url, + cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh + ) + ) + return@withContext maybeGunzip(artifact.readBytes()) } - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. 
- */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files - install(HttpTimeout) { - requestTimeoutMillis = 60000 // 60 seconds - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 60000 // 60 seconds - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path}") - } finally { - client.close() - } + private fun maybeGunzip(bytes: ByteArray): ByteArray { + if (!bytes.isGzip()) return bytes + return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } } - /** - * Decompresses a gzip file. - * - * @param gzipFilePath The path to the gzip file. - * @param outputFilePath The path to save the decompressed file to. - */ - private fun decompressGzipFile(gzipFilePath: String, outputFilePath: String) { - GZIPInputStream(FileInputStream(gzipFilePath)).use { gzipInputStream -> - FileOutputStream(outputFilePath).use { outputStream -> - val buffer = ByteArray(1024) - var len: Int - while (gzipInputStream.read(buffer).also { len = it } > 0) { - outputStream.write(buffer, 0, len) - } - } - } + private fun ByteArray.isGzip(): Boolean { + return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() } public companion object { @@ -137,4 +73,5 @@ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMN return FashionMNISTLoaderJvm(config) } } + } diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt index b6bcb9aa..330197d2 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt +++ 
b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt @@ -1,18 +1,13 @@ package sk.ainet.data.mnist -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import java.io.File -import java.io.FileInputStream -import java.io.FileOutputStream +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream import java.util.zip.GZIPInputStream +import java.io.File /** * JVM implementation of the MNIST loader. @@ -20,91 +15,32 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the MNIST loader. */ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) { + private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) /** - * Downloads and caches a file. + * Resolves, caches, and decompresses a file when needed. * * @param url The URL to download from. * @param filename The name of the file to save. * @return The bytes of the decompressed file. 
*/ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val cacheDir = File(config.cacheDir) - if (!cacheDir.exists()) { - cacheDir.mkdirs() - } - - val gzipFile = File(cacheDir, filename) - val decompressedFile = File(cacheDir, filename.removeSuffix(".gz")) - - // Check if the decompressed file already exists in cache - if (config.useCache && decompressedFile.exists()) { - println("Using cached file: ${decompressedFile.path}") - return@withContext decompressedFile.readBytes() - } - - // Check if the gzip file already exists in cache - if (!gzipFile.exists() || !config.useCache) { - println("Downloading file: $url") - downloadFile(url, gzipFile.path) - } else { - println("Using cached gzip file: ${gzipFile.path}") - } - - // Decompress the gzip file - println("Decompressing file: ${gzipFile.path}") - decompressGzipFile(gzipFile.path, decompressedFile.path) - - return@withContext decompressedFile.readBytes() + val artifact = resolver.resolve( + DataSourceRequest( + uri = url, + cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh + ) + ) + return@withContext maybeGunzip(artifact.readBytes()) } - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. 
- */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files - install(HttpTimeout) { - requestTimeoutMillis = 300000 // 5 minutes - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 300000 // 5 minutes - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path}") - } finally { - client.close() - } + private fun maybeGunzip(bytes: ByteArray): ByteArray { + if (!bytes.isGzip()) return bytes + return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } } - /** - * Decompresses a gzip file. - * - * @param gzipFilePath The path to the gzip file. - * @param outputFilePath The path to save the decompressed file to. - */ - private fun decompressGzipFile(gzipFilePath: String, outputFilePath: String) { - GZIPInputStream(FileInputStream(gzipFilePath)).use { gzipInputStream -> - FileOutputStream(outputFilePath).use { outputStream -> - val buffer = ByteArray(1024) - var len: Int - while (gzipInputStream.read(buffer).also { len = it } > 0) { - outputStream.write(buffer, 0, len) - } - } - } + private fun ByteArray.isGzip(): Boolean { + return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() } public companion object { @@ -137,4 +73,5 @@ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi return MNISTLoaderJvm(config) } } + } diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt index 06b23b39..541516b3 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt +++ 
b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt @@ -60,7 +60,8 @@ class CIFAR10LoaderTest { fun testLoaderConfiguration() { val config = CIFAR10LoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + archiveUri = "hf+https://huggingface.co/datasets/cifar10/resolve/main/cifar-10-binary.tar.gz" ) val loader = createCIFAR10Loader(config) diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt index 5f80dc31..30135890 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt @@ -60,7 +60,9 @@ class FashionMNISTLoaderTest { fun testLoaderConfiguration() { val config = FashionMNISTLoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + trainImagesUri = "file:///datasets/fashion-mnist/train-images", + trainLabelsUri = "hf+https://huggingface.co/datasets/zalando-datasets/fashion_mnist/resolve/main/train-labels" ) val loader = createFashionMNISTLoader(config) diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt index 82e04d75..94796645 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.mnist.MNISTImage import sk.ainet.data.mnist.MNISTLoaderConfig import sk.ainet.data.mnist.MNISTLoaderFactory import sk.ainet.data.mnist.MNISTLoaderCommon +import java.nio.file.Files import 
kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -60,12 +61,37 @@ class MNISTLoaderTest { fun testLoaderConfiguration() { val config = MNISTLoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + trainImagesUri = "file:///datasets/mnist/train-images", + trainLabelsUri = "hf+https://huggingface.co/datasets/mnist/mnist/resolve/main/train-labels" ) val loader = MNISTLoaderFactory.create(config) assertNotNull(loader) } + + @Test + fun testJvmLoaderReadsConfiguredFileUris() = runBlocking { + val root = Files.createTempDirectory("skainet-mnist-loader-test").toFile() + try { + val trainImages = root.resolve("train-images.idx") + val trainLabels = root.resolve("train-labels.idx") + trainImages.writeBytes(TRAINING_IMAGES_BYTES) + trainLabels.writeBytes(TRAINING_LABELS_BYTES) + val config = MNISTLoaderConfig( + cacheDir = root.resolve("cache").absolutePath, + useCache = false, + trainImagesUri = trainImages.toURI().toString(), + trainLabelsUri = trainLabels.toURI().toString() + ) + + val dataset = MNISTLoaderFactory.create(config).loadTrainingData() + + assertEquals(EXPECTED_TRAINING_DATA, dataset.images) + } finally { + root.deleteRecursively() + } + } } private fun createFakeLoader(config: MNISTLoaderConfig = MNISTLoaderConfig()): FakeMNISTLoader { From 75ac4606a998f6c9133b09c6d29dfa5bbc32eeed Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:43:08 +0200 Subject: [PATCH 08/18] docs: explain data source URIs --- README.md | 1 + build.gradle.kts | 3 +- docs/modules/ROOT/nav.adoc | 1 + .../data-sources-getting-started.adoc | 117 ++++++++++++++++++ 4 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc diff --git a/README.md b/README.md index 59beabfe..56eda582 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,7 @@ Runnable examples: ### Data and I/O - Built-in loaders: MNIST, Fashion-MNIST, CIFAR-10 +- 
URI-backed data sources: `file://`, `https://`, `hf+https://`, and `hf://...` - Formats: GGUF, ONNX, SafeTensors, JSON, Image (JPEG, PNG) - Type-safe transform DSL: resize, crop, normalize, toTensor diff --git a/build.gradle.kts b/build.gradle.kts index 771b2133..d76ff097 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -151,6 +151,7 @@ dependencies { // skainet-data dokka(project(":skainet-data:skainet-data-api")) + dokka(project(":skainet-data:skainet-data-source")) dokka(project(":skainet-data:skainet-data-transform")) dokka(project(":skainet-data:skainet-data-simple")) dokka(project(":skainet-data:skainet-data-media")) @@ -178,4 +179,4 @@ tasks.register("bundleDokkaIntoSite") { dependsOn("dokkaGenerate") from(layout.buildDirectory.dir("dokka/html")) into(layout.projectDirectory.dir("docs/build/site/api")) -} \ No newline at end of file +} diff --git a/docs/modules/ROOT/nav.adoc b/docs/modules/ROOT/nav.adoc index 1c8ed540..e6a714cd 100644 --- a/docs/modules/ROOT/nav.adoc +++ b/docs/modules/ROOT/nav.adoc @@ -5,6 +5,7 @@ * Tutorials ** xref:tutorials/kotlin-getting-started.adoc[Kotlin getting started] ** xref:tutorials/java-getting-started.adoc[Java getting started] +** xref:tutorials/data-sources-getting-started.adoc[Data sources and Hugging Face] ** xref:tutorials/image-data-getting-started.adoc[Image and data API] ** xref:tutorials/hlo-getting-started.adoc[StableHLO getting started] ** xref:tutorials/minerva-getting-started.adoc[Minerva getting started] diff --git a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc new file mode 100644 index 00000000..b4001501 --- /dev/null +++ b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc @@ -0,0 +1,117 @@ +== Data sources and Hugging Face + +SKaiNET separates artifact resolution from dataset parsing and preprocessing. 
+Use `skainet-data-source` when a dataset, tokenizer, model sidecar, or fixture +can live either on disk or behind a remote URI. + +[cols="1,3",options="header"] +|=== +| URI form | Meaning +| `file:///path/to/file` +| Read a local file. + +| `https://host/path/file` +| Download and cache a generic remote artifact. + +| `hf+https://huggingface.co/org/repo/resolve/main/file` +| Treat a Hugging Face resolve URL as a Hugging Face artifact. + +| `hf://org/repo@main/path/file` +| Expand to a Hugging Face model repository resolve URL. + +| `hf://datasets/org/repo@main/path/file` +| Expand to a Hugging Face dataset repository resolve URL. +|=== + +=== Add the modules + +For JVM consumers, add the source module beside the data loaders you use: + +[source,kotlin] +---- +dependencies { + implementation(platform("sk.ainet:skainet-bom:0.32.4")) + + implementation("sk.ainet.core:skainet-data-source-jvm") + implementation("sk.ainet.core:skainet-data-simple-jvm") +} +---- + +=== Resolve one artifact + +`JvmDataSourceResolver` materializes remote artifacts into a cache and returns +a `DataSourceArtifact` that can be read as bytes. Public Hugging Face files do +not need credentials. Private files can use an `Authorization` header, or the +JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the +environment when the URI provider is Hugging Face. + +[source,kotlin] +---- +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver + +val resolver = JvmDataSourceResolver() +val artifact = resolver.resolve( + DataSourceRequest( + uri = "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" + ) +) + +println(artifact.filename) +println(artifact.localPath) + +val bytes = artifact.readBytes() +---- + +=== Use sources with built-in loaders + +MNIST and Fashion-MNIST expose per-file URI overrides. CIFAR-10 exposes an +archive URI override. 
Defaults still point to the historical public dataset +locations, so existing code keeps working. + +[source,kotlin] +---- +import sk.ainet.data.mnist.MNIST +import sk.ainet.data.mnist.MNISTLoaderConfig + +val train = MNIST.loadTrain( + MNISTLoaderConfig( + trainImagesUri = "file:///datasets/mnist/train-images-idx3-ubyte", + trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz" + ) +) + +val batches = train.batchIterator(batchSize = 64) +---- + +=== Cache behavior + +Use `CachePolicy.Use` for normal operation, `Refresh` to re-download, +`Offline` to require a cached copy, and `Bypass` to avoid writing the cache. +Built-in JVM loaders map `useCache = true` to `Use` and `useCache = false` +to `Refresh`. + +[source,kotlin] +---- +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest + +val refreshed = resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/your-org/your-dataset@main/data/train-00000.parquet", + cachePolicy = CachePolicy.Refresh + ) +) +---- + +=== Keep preprocessing separate + +After bytes are parsed into a dataset, continue using the existing transform +DSL for image/tensor preprocessing: + +[source,kotlin] +---- +import sk.ainet.data.transform.mnistPreprocessing + +val preprocessing = mnistPreprocessing(ctx) +---- From 130702ff82833d0e3931e318c721fb85e950ca6d Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 14:22:21 +0200 Subject: [PATCH 09/18] data: share source resolver core --- .../sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt | 26 +-- .../data/common/JvmDatasetSourceReader.kt | 40 +++++ .../fashionmnist/FashionMNISTLoaderJvm.kt | 30 +--- .../sk/ainet/data/mnist/MNISTLoaderJvm.kt | 30 +--- .../skainet-data-source/build.gradle.kts | 4 +- .../data/source/DefaultDataSourceResolver.kt | 138 +++++++++++++++ .../source/DefaultDataSourceResolverTest.kt | 164 ++++++++++++++++++ .../data/source/JvmDataSourceResolver.kt | 144 +++++---------- 8 files 
changed, 404 insertions(+), 172 deletions(-) create mode 100644 skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt create mode 100644 skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt index c6ddcce5..10fc06a3 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt @@ -2,13 +2,10 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import sk.ainet.data.source.CachePolicy -import sk.ainet.data.source.DataSourceRequest -import sk.ainet.data.source.JvmDataSourceResolver -import java.io.ByteArrayInputStream +import sk.ainet.data.common.JvmDatasetSourceReader +import sk.ainet.data.common.gunzip import java.io.File import java.io.FileOutputStream -import java.util.zip.GZIPInputStream /** * JVM implementation of the CIFAR-10 loader. @@ -18,7 +15,7 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the CIFAR-10 loader. */ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) { - private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) + private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) /** * Downloads the CIFAR-10 archive and extracts the specified batch file. 
@@ -43,14 +40,8 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon // Check if we need to resolve and extract the archive if (!extractedDir.exists() || !config.useCache) { - val archive = resolver.resolve( - DataSourceRequest( - uri = config.archiveUri, - cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh - ) - ) println("Extracting CIFAR-10 archive...") - extractTarGz(archive.readBytes(), cacheDir.path) + extractTarGz(sources.read(config.archiveUri), cacheDir.path) } if (!batchFile.exists()) { @@ -68,14 +59,7 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon */ private fun extractTarGz(archiveBytes: ByteArray, outputDir: String) { val outputDirFile = File(outputDir) - - // First, decompress gzip to get the tar content - val tarBytes = GZIPInputStream(ByteArrayInputStream(archiveBytes)).use { gzipIn -> - gzipIn.readBytes() - } - - // Parse the TAR archive - extractTar(tarBytes, outputDirFile) + extractTar(archiveBytes.gunzip(), outputDirFile) } /** diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt new file mode 100644 index 00000000..29ce6bbf --- /dev/null +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt @@ -0,0 +1,40 @@ +package sk.ainet.data.common + +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream +import java.io.File +import java.util.zip.GZIPInputStream + +internal class JvmDatasetSourceReader( + cacheDir: String, + useCache: Boolean +) { + private val resolver = JvmDataSourceResolver(File(cacheDir, "sources")) + private val cachePolicy = if (useCache) CachePolicy.Use else CachePolicy.Refresh + + suspend fun read(uri: String): 
ByteArray { + val artifact = resolver.resolve( + DataSourceRequest( + uri = uri, + cachePolicy = cachePolicy + ) + ) + return artifact.readBytes() + } + + suspend fun readGzipDecoded(uri: String): ByteArray = read(uri).gunzipIfNeeded() +} + +internal fun ByteArray.gunzip(): ByteArray { + return GZIPInputStream(ByteArrayInputStream(this)).use { it.readBytes() } +} + +internal fun ByteArray.gunzipIfNeeded(): ByteArray { + return if (isGzip()) gunzip() else this +} + +private fun ByteArray.isGzip(): Boolean { + return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() +} diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt index 332dfeb2..9b100db9 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt @@ -1,13 +1,6 @@ package sk.ainet.data.fashionmnist -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import sk.ainet.data.source.CachePolicy -import sk.ainet.data.source.DataSourceRequest -import sk.ainet.data.source.JvmDataSourceResolver -import java.io.ByteArrayInputStream -import java.util.zip.GZIPInputStream -import java.io.File +import sk.ainet.data.common.JvmDatasetSourceReader /** * JVM implementation of the Fashion-MNIST loader. @@ -15,7 +8,7 @@ import java.io.File * @property config The configuration for the Fashion-MNIST loader. */ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) { - private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) + private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) /** * Resolves, caches, and decompresses a file when needed. 
@@ -24,23 +17,8 @@ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMN * @param filename The name of the file to save. * @return The bytes of the decompressed file. */ - override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val artifact = resolver.resolve( - DataSourceRequest( - uri = url, - cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh - ) - ) - return@withContext maybeGunzip(artifact.readBytes()) - } - - private fun maybeGunzip(bytes: ByteArray): ByteArray { - if (!bytes.isGzip()) return bytes - return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } - } - - private fun ByteArray.isGzip(): Boolean { - return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() + override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray { + return sources.readGzipDecoded(url) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt index 330197d2..e5466b4f 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt @@ -1,13 +1,6 @@ package sk.ainet.data.mnist -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import sk.ainet.data.source.CachePolicy -import sk.ainet.data.source.DataSourceRequest -import sk.ainet.data.source.JvmDataSourceResolver -import java.io.ByteArrayInputStream -import java.util.zip.GZIPInputStream -import java.io.File +import sk.ainet.data.common.JvmDatasetSourceReader /** * JVM implementation of the MNIST loader. @@ -15,7 +8,7 @@ import java.io.File * @property config The configuration for the MNIST loader. 
*/ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) { - private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) + private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) /** * Resolves, caches, and decompresses a file when needed. @@ -24,23 +17,8 @@ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi * @param filename The name of the file to save. * @return The bytes of the decompressed file. */ - override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val artifact = resolver.resolve( - DataSourceRequest( - uri = url, - cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh - ) - ) - return@withContext maybeGunzip(artifact.readBytes()) - } - - private fun maybeGunzip(bytes: ByteArray): ByteArray { - if (!bytes.isGzip()) return bytes - return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } - } - - private fun ByteArray.isGzip(): Boolean { - return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() + override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray { + return sources.readGzipDecoded(url) } public companion object { diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts index f8b4dbc0..7e4ee92d 100644 --- a/skainet-data/skainet-data-source/build.gradle.kts +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -22,6 +22,7 @@ kotlin { commonTest.dependencies { implementation(libs.kotlin.test) + implementation(libs.kotlinx.coroutines.test) } jvmMain.dependencies { @@ -31,8 +32,5 @@ kotlin { implementation(libs.kotlinx.coroutines.core.jvm) } - jvmTest.dependencies { - implementation(libs.kotlinx.coroutines.test) - } } } diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt 
b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt new file mode 100644 index 00000000..4524221e --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt @@ -0,0 +1,138 @@ +package sk.ainet.data.source + +/** + * Fetches a remote URI into memory. Kept injectable so tests and applications + * can provide their own HTTP stack or policy layer. + */ +public fun interface RemoteDataSourceFetcher { + public suspend fun fetch(uri: String, headers: Map): ByteArray +} + +/** + * Adds platform or application-specific headers to a resolved remote request. + */ +public fun interface DataSourceHeaderProvider { + public fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map +} + +/** + * Computes checksums for integrity verification without tying resolver policy + * to a concrete platform crypto API. + */ +public fun interface DataSourceChecksum { + public fun sha256Hex(bytes: ByteArray): String +} + +/** + * Platform storage adapter used by [DefaultDataSourceResolver]. + */ +public interface DataSourceByteStore { + public suspend fun readLocal(path: String): DataSourceStoredArtifact? + public suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? + public suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact +} + +/** + * A platform materialized artifact used by the common resolver core. + */ +public class DataSourceStoredArtifact( + public val localPath: String?, + public val sizeBytes: Long?, + private val byteReader: suspend () -> ByteArray +) { + public suspend fun readBytes(): ByteArray = byteReader() +} + +/** + * Platform-neutral resolver implementation for local files, HTTP(S), and + * Hugging Face source URIs. Storage, network, auth, and checksum details are + * injected so this policy can be reused by each KMP target. 
+ */ +public class DefaultDataSourceResolver( + private val store: DataSourceByteStore, + private val fetcher: RemoteDataSourceFetcher, + private val checksum: DataSourceChecksum, + private val headerProvider: DataSourceHeaderProvider = DataSourceHeaderProvider { request, _ -> + request.headers + } +) : DataSourceResolver { + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact { + val parsed = DataSourceUriParser.parse(request.uri) + return when (parsed.provider) { + DataSourceProvider.File -> resolveFile(request, parsed) + DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed) + } + } + + private suspend fun resolveFile( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") + val stored = store.readLocal(path) + ?: throw DataSourceException("Data source file not found: $path") + request.expectedSha256?.let { verifySha256(stored.readBytes(), it, request.uri) } + return stored.toArtifact(request, parsed, cacheHit = true) + } + + private suspend fun resolveRemote( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline + if (canUseCache) { + val cached = store.readCache(parsed.cacheKey) + if (cached != null) { + request.expectedSha256?.let { verifySha256(cached.readBytes(), it, request.uri) } + return cached.toArtifact(request, parsed, cacheHit = true) + } + } + + if (request.cachePolicy == CachePolicy.Offline) { + throw DataSourceException("No cached artifact available for offline source: ${request.uri}") + } + + val bytes = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed)) + request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } + + if (request.cachePolicy == CachePolicy.Bypass) { + return 
DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = null, + sizeBytes = bytes.size.toLong(), + cacheHit = false, + byteReader = { bytes } + ) + } + + val stored = store.writeCache(parsed.cacheKey, bytes) + return stored.toArtifact(request, parsed, cacheHit = false) + } + + private suspend fun DataSourceStoredArtifact.toArtifact( + request: DataSourceRequest, + parsed: ParsedDataSourceUri, + cacheHit: Boolean + ): DataSourceArtifact { + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = localPath, + sizeBytes = sizeBytes, + cacheHit = cacheHit, + byteReader = { readBytes() } + ) + } + + private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { + val actual = checksum.sha256Hex(bytes) + if (!actual.equals(expected, ignoreCase = true)) { + throw DataSourceException( + "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" + ) + } + } +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt new file mode 100644 index 00000000..1b864465 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt @@ -0,0 +1,164 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class DefaultDataSourceResolverTest { + @Test + fun resolvesLocalArtifactsThroughStore() = runTest { + val store = MemoryDataSourceByteStore( + localArtifacts = mapOf("/data/sample.bin" to "local".encodeToByteArray()) + ) + val fetcher = RecordingFetcher("remote".encodeToByteArray()) + val 
resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + + val artifact = resolver.resolve(DataSourceRequest("/data/sample.bin")) + + assertEquals("/data/sample.bin", artifact.localPath) + assertTrue(artifact.cacheHit) + assertEquals(0, fetcher.calls) + assertContentEquals("local".encodeToByteArray(), artifact.readBytes()) + } + + @Test + fun cachesRemoteArtifactsAndReusesThem() = runTest { + val store = MemoryDataSourceByteStore() + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + val request = DataSourceRequest( + uri = "https://example.test/data.bin", + expectedSha256 = "sha:payload" + ) + + val first = resolver.resolve(request) + val second = resolver.resolve(request) + + assertFalse(first.cacheHit) + assertTrue(second.cacheHit) + assertEquals(1, fetcher.calls) + assertEquals(1, store.cacheWrites) + assertContentEquals("payload".encodeToByteArray(), second.readBytes()) + } + + @Test + fun bypassSkipsPersistentCache() = runTest { + val store = MemoryDataSourceByteStore() + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + + val artifact = resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + cachePolicy = CachePolicy.Bypass + ) + ) + + assertEquals(null, artifact.localPath) + assertFalse(artifact.cacheHit) + assertEquals(1, fetcher.calls) + assertEquals(0, store.cacheWrites) + } + + @Test + fun verifiesChecksumsInCommonCore() = runTest { + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = RecordingFetcher("payload".encodeToByteArray()), + checksum = TestChecksum + ) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + expectedSha256 = "sha:other" + ) + ) + } + } + + @Test + fun forwardsProviderHeadersToFetcher() = runTest { + val fetcher = 
RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum, + headerProvider = DataSourceHeaderProvider { request, parsedUri -> + request.headers + ("X-SKaiNET-Provider" to parsedUri.provider.name) + } + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/org/repo@main/file.bin", + headers = mapOf("Accept" to "application/octet-stream") + ) + ) + + assertEquals( + mapOf( + "Accept" to "application/octet-stream", + "X-SKaiNET-Provider" to "HuggingFace" + ), + fetcher.lastHeaders + ) + } +} + +private class MemoryDataSourceByteStore( + private val localArtifacts: Map = emptyMap() +) : DataSourceByteStore { + private val cacheArtifacts = mutableMapOf() + + var cacheWrites: Int = 0 + private set + + override suspend fun readLocal(path: String): DataSourceStoredArtifact? { + return localArtifacts[path]?.storedAt(path) + } + + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? 
{ + return cacheArtifacts[cacheKey]?.storedAt("/cache/$cacheKey") + } + + override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { + cacheWrites++ + cacheArtifacts[cacheKey] = bytes + return bytes.storedAt("/cache/$cacheKey") + } + + private fun ByteArray.storedAt(path: String): DataSourceStoredArtifact { + val bytes = copyOf() + return DataSourceStoredArtifact( + localPath = path, + sizeBytes = bytes.size.toLong(), + byteReader = { bytes.copyOf() } + ) + } +} + +private class RecordingFetcher( + private val bytes: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + var lastHeaders: Map = emptyMap() + private set + + override suspend fun fetch(uri: String, headers: Map): ByteArray { + calls++ + lastHeaders = headers + return bytes.copyOf() + } +} + +private object TestChecksum : DataSourceChecksum { + override fun sha256Hex(bytes: ByteArray): String = "sha:${bytes.decodeToString()}" +} diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt index 7cba502a..0dec564d 100644 --- a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -11,14 +11,6 @@ import kotlinx.coroutines.withContext import java.io.File import java.security.MessageDigest -/** - * Fetches a remote URI into memory. Kept injectable so tests and applications - * can provide their own HTTP stack or policy layer. - */ -public fun interface RemoteDataSourceFetcher { - public suspend fun fetch(uri: String, headers: Map): ByteArray -} - /** * Ktor/CIO-backed remote fetcher for JVM data artifacts. */ @@ -47,99 +39,70 @@ public class KtorRemoteDataSourceFetcher( * JVM resolver for local files and cached remote artifacts. 
*/ public class JvmDataSourceResolver( - private val cacheDir: File = defaultCacheDir(), - private val fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() + cacheDir: File = defaultCacheDir(), + fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() ) : DataSourceResolver { - override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { - val parsed = DataSourceUriParser.parse(request.uri) - when (parsed.provider) { - DataSourceProvider.File -> resolveFile(request, parsed) - DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed) - } - } + private val delegate = DefaultDataSourceResolver( + store = JvmFileDataSourceByteStore(cacheDir), + fetcher = fetcher, + checksum = JvmSha256DataSourceChecksum, + headerProvider = JvmHuggingFaceHeaderProvider + ) - private fun resolveFile( - request: DataSourceRequest, - parsed: ParsedDataSourceUri - ): DataSourceArtifact { - val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") - val file = File(path) - require(file.exists()) { "Data source file not found: ${file.absolutePath}" } - require(file.isFile) { "Data source path is not a file: ${file.absolutePath}" } - request.expectedSha256?.let { verifySha256(file.readBytes(), it, request.uri) } - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = file.absolutePath, - sizeBytes = file.length(), - cacheHit = true, - byteReader = { file.readBytes() } - ) + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { + delegate.resolve(request) } - private suspend fun resolveRemote( - request: DataSourceRequest, - parsed: ParsedDataSourceUri - ): DataSourceArtifact { - val target = File(cacheDir, parsed.cacheKey) - val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline - if 
(canUseCache && target.exists() && target.isFile) { - request.expectedSha256?.let { verifySha256(target.readBytes(), it, request.uri) } - return cachedArtifact(request, parsed, target, cacheHit = true) + public companion object { + public fun defaultCacheDir(): File { + val userHome = System.getProperty("user.home")?.takeIf { it.isNotBlank() } + val base = userHome ?: System.getProperty("java.io.tmpdir") + return File(base, ".cache/skainet/data") } + } +} - if (request.cachePolicy == CachePolicy.Offline) { - throw DataSourceException("No cached artifact available for offline source: ${request.uri}") +internal class JvmFileDataSourceByteStore( + private val cacheDir: File +) : DataSourceByteStore { + override suspend fun readLocal(path: String): DataSourceStoredArtifact? { + val file = File(path) + if (!file.exists()) return null + if (!file.isFile) { + throw DataSourceException("Data source path is not a file: ${file.absolutePath}") } + return file.toStoredArtifact() + } - val bytes = fetcher.fetch(parsed.transportUri, requestHeaders(request, parsed)) - request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } - - if (request.cachePolicy == CachePolicy.Bypass) { - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = null, - sizeBytes = bytes.size.toLong(), - cacheHit = false, - byteReader = { bytes } - ) - } + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? 
{ + val target = File(cacheDir, cacheKey) + return if (target.exists() && target.isFile) target.toStoredArtifact() else null + } + override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { cacheDir.mkdirs() - val temp = File(cacheDir, "${parsed.cacheKey}.tmp") + val target = File(cacheDir, cacheKey) + val temp = File(cacheDir, "$cacheKey.tmp") temp.writeBytes(bytes) if (!temp.renameTo(target)) { temp.copyTo(target, overwrite = true) temp.delete() } - return cachedArtifact(request, parsed, target, cacheHit = false) + return target.toStoredArtifact() } - private fun cachedArtifact( - request: DataSourceRequest, - parsed: ParsedDataSourceUri, - target: File, - cacheHit: Boolean - ): DataSourceArtifact { - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = target.absolutePath, - sizeBytes = target.length(), - cacheHit = cacheHit, - byteReader = { target.readBytes() } + private fun File.toStoredArtifact(): DataSourceStoredArtifact { + return DataSourceStoredArtifact( + localPath = absolutePath, + sizeBytes = length(), + byteReader = { readBytes() } ) } +} - private fun requestHeaders( - request: DataSourceRequest, - parsed: ParsedDataSourceUri - ): Map { - if (parsed.provider != DataSourceProvider.HuggingFace) return request.headers +internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { + override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { + if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers if (request.headers.keys.any { it.equals("Authorization", ignoreCase = true) }) return request.headers val token = System.getenv("HF_TOKEN") ?.takeIf { it.isNotBlank() } @@ -147,23 +110,12 @@ public class JvmDataSourceResolver( ?: return request.headers return request.headers + ("Authorization" to "Bearer $token") } +} - private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { - 
val actual = MessageDigest.getInstance("SHA-256") +internal object JvmSha256DataSourceChecksum : DataSourceChecksum { + override fun sha256Hex(bytes: ByteArray): String { + return MessageDigest.getInstance("SHA-256") .digest(bytes) .joinToString("") { byte -> "%02x".format(byte) } - if (!actual.equals(expected, ignoreCase = true)) { - throw DataSourceException( - "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" - ) - } - } - - public companion object { - public fun defaultCacheDir(): File { - val userHome = System.getProperty("user.home")?.takeIf { it.isNotBlank() } - val base = userHome ?: System.getProperty("java.io.tmpdir") - return File(base, ".cache/skainet/data") - } } } From 40a1ab7250e60ae18765ccad40c62da4a08008a7 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 15:00:54 +0200 Subject: [PATCH 10/18] data: stream source artifacts with kotlinx-io --- .../data-sources-getting-started.adoc | 14 +- .../skainet-data-source/build.gradle.kts | 1 + .../sk/ainet/data/source/DataSourceModels.kt | 37 ++- .../data/source/DefaultDataSourceResolver.kt | 269 ++++++++++++++++-- .../source/DefaultDataSourceResolverTest.kt | 56 +++- .../data/source/JvmDataSourceResolver.kt | 72 ++--- .../data/source/JvmDataSourceResolverTest.kt | 10 +- 7 files changed, 361 insertions(+), 98 deletions(-) diff --git a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc index b4001501..34eb72a7 100644 --- a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc +++ b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc @@ -40,9 +40,9 @@ dependencies { === Resolve one artifact `JvmDataSourceResolver` materializes remote artifacts into a cache and returns -a `DataSourceArtifact` that can be read as bytes. Public Hugging Face files do -not need credentials. 
Private files can use an `Authorization` header, or the -JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the +a `DataSourceArtifact` that opens a `kotlinx.io.Source`. Public Hugging Face +files do not need credentials. Private files can use an `Authorization` header, +or the JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the environment when the URI provider is Hugging Face. [source,kotlin] @@ -60,6 +60,14 @@ val artifact = resolver.resolve( println(artifact.filename) println(artifact.localPath) +val source = artifact.openSource() +try { + // Pass the source to a parser/loader for model-sized artifacts. +} finally { + source.close() +} + +// Convenience for small sidecars and tests. val bytes = artifact.readBytes() ---- diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts index 7e4ee92d..7d7c3065 100644 --- a/skainet-data/skainet-data-source/build.gradle.kts +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -18,6 +18,7 @@ kotlin { sourceSets { commonMain.dependencies { implementation(libs.kotlinx.coroutines) + implementation(libs.kotlinx.io.core) } commonTest.dependencies { diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt index 71eacf32..0e34245d 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt @@ -1,5 +1,9 @@ package sk.ainet.data.source +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.readByteArray + /** * Cache behavior requested by a caller resolving a data artifact. 
*/ @@ -80,9 +84,38 @@ public class DataSourceArtifact( public val localPath: String?, public val sizeBytes: Long?, public val cacheHit: Boolean, - private val byteReader: suspend () -> ByteArray + private val sourceOpener: suspend () -> Source ) { - public suspend fun readBytes(): ByteArray = byteReader() + /** + * Opens a fresh source for this artifact. Callers own and must close it. + */ + public suspend fun openSource(): Source = sourceOpener() + + /** + * Convenience for small artifacts. Prefer [openSource] or [copyTo] for + * model-scale data. + */ + public suspend fun readBytes(): ByteArray { + val source = openSource() + return try { + source.readByteArray() + } finally { + source.close() + } + } + + /** + * Streams this artifact into [sink]. The source is closed after copying; + * [sink] is left open for the caller. + */ + public suspend fun copyTo(sink: Sink): Long { + val source = openSource() + return try { + source.transferTo(sink) + } finally { + source.close() + } + } } /** diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt index 4524221e..b7ee217f 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt @@ -1,11 +1,38 @@ package sk.ainet.data.source +import kotlinx.io.Buffer +import kotlinx.io.RawSource +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.buffered +import kotlinx.io.files.FileSystem +import kotlinx.io.files.Path +import kotlinx.io.files.SystemFileSystem +import kotlinx.io.readByteArray + +/** + * Remote response body exposed as a one-shot [Source]. + */ +public class DataSourceRemoteContent( + public val source: Source, + public val sizeBytes: Long? 
= null +) { + public companion object { + public fun fromBytes(bytes: ByteArray): DataSourceRemoteContent { + return DataSourceRemoteContent( + source = bytes.toDataSourceSource(), + sizeBytes = bytes.size.toLong() + ) + } + } +} + /** - * Fetches a remote URI into memory. Kept injectable so tests and applications + * Fetches a remote URI as a stream. Kept injectable so tests and applications * can provide their own HTTP stack or policy layer. */ public fun interface RemoteDataSourceFetcher { - public suspend fun fetch(uri: String, headers: Map): ByteArray + public suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent } /** @@ -19,28 +46,153 @@ public fun interface DataSourceHeaderProvider { * Computes checksums for integrity verification without tying resolver policy * to a concrete platform crypto API. */ -public fun interface DataSourceChecksum { - public fun sha256Hex(bytes: ByteArray): String +public interface DataSourceChecksum { + public fun newSha256(): DataSourceHash +} + +/** + * Incremental hash state used while streaming artifact bytes. + */ +public interface DataSourceHash { + public fun update(bytes: ByteArray, startIndex: Int = 0, endIndex: Int = bytes.size) + public fun hex(): String } /** * Platform storage adapter used by [DefaultDataSourceResolver]. */ -public interface DataSourceByteStore { +public interface DataSourceArtifactStore { public suspend fun readLocal(path: String): DataSourceStoredArtifact? public suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? - public suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact + public suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long? = null, + validate: suspend (DataSourceStoredArtifact) -> Unit = {} + ): DataSourceStoredArtifact } /** - * A platform materialized artifact used by the common resolver core. + * A materialized artifact used by the common resolver core. 
*/ public class DataSourceStoredArtifact( public val localPath: String?, public val sizeBytes: Long?, - private val byteReader: suspend () -> ByteArray + private val sourceOpener: suspend () -> Source ) { - public suspend fun readBytes(): ByteArray = byteReader() + public suspend fun openSource(): Source = sourceOpener() + + public suspend fun readBytes(): ByteArray { + val source = openSource() + return try { + source.readByteArray() + } finally { + source.close() + } + } + + public suspend fun copyTo(sink: Sink): Long { + val source = openSource() + return try { + source.transferTo(sink) + } finally { + source.close() + } + } + + public companion object { + public fun inMemory( + bytes: ByteArray, + localPath: String? = null + ): DataSourceStoredArtifact { + val owned = bytes.copyOf() + return DataSourceStoredArtifact( + localPath = localPath, + sizeBytes = owned.size.toLong(), + sourceOpener = { owned.toDataSourceSource() } + ) + } + + public fun inMemoryFrom( + source: Source, + localPath: String? = null, + sizeBytes: Long? = null + ): DataSourceStoredArtifact { + val buffer = Buffer() + val copied = source.transferTo(buffer) + return DataSourceStoredArtifact( + localPath = localPath, + sizeBytes = sizeBytes ?: copied, + sourceOpener = { buffer.copy() } + ) + } + } +} + +/** + * Filesystem-backed artifact store built on kotlinx-io so the cache policy + * remains reusable across KMP targets that expose [SystemFileSystem]. + */ +public class FileSystemDataSourceArtifactStore( + private val cacheDir: Path, + private val fileSystem: FileSystem = SystemFileSystem +) : DataSourceArtifactStore { + override suspend fun readLocal(path: String): DataSourceStoredArtifact? 
{ + val localPath = Path(path) + val metadata = fileSystem.metadataOrNull(localPath) ?: return null + if (!metadata.isRegularFile) { + throw DataSourceException("Data source path is not a file: $path") + } + val resolved = fileSystem.resolve(localPath) + return resolved.toStoredArtifact(metadata.size) + } + + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? { + val target = Path(cacheDir, cacheKey) + val metadata = fileSystem.metadataOrNull(target) ?: return null + return if (metadata.isRegularFile) target.toStoredArtifact(metadata.size) else null + } + + override suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long?, + validate: suspend (DataSourceStoredArtifact) -> Unit + ): DataSourceStoredArtifact { + fileSystem.createDirectories(cacheDir) + val target = Path(cacheDir, cacheKey) + val temp = Path(cacheDir, "$cacheKey.tmp") + + val sink = fileSystem.sink(temp).buffered() + try { + source.transferTo(sink) + sink.flush() + } finally { + sink.close() + } + + val tempMetadata = fileSystem.metadataOrNull(temp) + val tempArtifact = temp.toStoredArtifact(tempMetadata?.size ?: sizeBytes) + try { + validate(tempArtifact) + } catch (throwable: Throwable) { + fileSystem.delete(temp, mustExist = false) + throw throwable + } + + fileSystem.atomicMove(temp, target) + val metadata = fileSystem.metadataOrNull(target) + return target.toStoredArtifact(metadata?.size ?: sizeBytes) + } + + private fun Path.toStoredArtifact(sizeBytes: Long?): DataSourceStoredArtifact { + val path = this + return DataSourceStoredArtifact( + localPath = path.toString(), + sizeBytes = sizeBytes, + sourceOpener = { fileSystem.source(path).buffered() } + ) + } } /** @@ -49,7 +201,7 @@ public class DataSourceStoredArtifact( * injected so this policy can be reused by each KMP target. 
*/ public class DefaultDataSourceResolver( - private val store: DataSourceByteStore, + private val store: DataSourceArtifactStore, private val fetcher: RemoteDataSourceFetcher, private val checksum: DataSourceChecksum, private val headerProvider: DataSourceHeaderProvider = DataSourceHeaderProvider { request, _ -> @@ -71,7 +223,7 @@ public class DefaultDataSourceResolver( val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") val stored = store.readLocal(path) ?: throw DataSourceException("Data source file not found: $path") - request.expectedSha256?.let { verifySha256(stored.readBytes(), it, request.uri) } + request.expectedSha256?.let { verifySha256(stored, it, request.uri) } return stored.toArtifact(request, parsed, cacheHit = true) } @@ -83,7 +235,7 @@ public class DefaultDataSourceResolver( if (canUseCache) { val cached = store.readCache(parsed.cacheKey) if (cached != null) { - request.expectedSha256?.let { verifySha256(cached.readBytes(), it, request.uri) } + request.expectedSha256?.let { verifySha256(cached, it, request.uri) } return cached.toArtifact(request, parsed, cacheHit = true) } } @@ -92,22 +244,35 @@ public class DefaultDataSourceResolver( throw DataSourceException("No cached artifact available for offline source: ${request.uri}") } - val bytes = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed)) - request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } + val remote = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed)) if (request.cachePolicy == CachePolicy.Bypass) { - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = null, - sizeBytes = bytes.size.toLong(), - cacheHit = false, - byteReader = { bytes } - ) + val stored = try { + DataSourceStoredArtifact.inMemoryFrom(remote.source, sizeBytes = remote.sizeBytes) + } finally { + remote.source.close() + } + 
request.expectedSha256?.let { verifySha256(stored, it, request.uri) } + return stored.toArtifact(request, parsed, cacheHit = false) } - val stored = store.writeCache(parsed.cacheKey, bytes) + val expectedSha256 = request.expectedSha256 + val hash = expectedSha256?.let { checksum.newSha256() } + val source = hash?.let { HashingRawSource(remote.source, it).buffered() } ?: remote.source + val stored = try { + store.writeCache( + cacheKey = parsed.cacheKey, + source = source, + sizeBytes = remote.sizeBytes, + validate = { + if (expectedSha256 != null && hash != null) { + verifySha256Hex(hash.hex(), expectedSha256, request.uri) + } + } + ) + } finally { + source.close() + } return stored.toArtifact(request, parsed, cacheHit = false) } @@ -123,16 +288,66 @@ public class DefaultDataSourceResolver( localPath = localPath, sizeBytes = sizeBytes, cacheHit = cacheHit, - byteReader = { readBytes() } + sourceOpener = { openSource() } ) } - private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { - val actual = checksum.sha256Hex(bytes) + private suspend fun verifySha256(artifact: DataSourceStoredArtifact, expected: String, uri: String) { + val actual = artifact.sha256Hex() + verifySha256Hex(actual, expected, uri) + } + + private fun verifySha256Hex(actual: String, expected: String, uri: String) { if (!actual.equals(expected, ignoreCase = true)) { throw DataSourceException( "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" ) } } + + private suspend fun DataSourceStoredArtifact.sha256Hex(): String { + val hash = checksum.newSha256() + val buffer = ByteArray(STREAM_BUFFER_SIZE) + val source = openSource() + try { + while (true) { + val read = source.readAtMostTo(buffer) + if (read == -1) break + hash.update(buffer, endIndex = read) + } + } finally { + source.close() + } + return hash.hex() + } + + private companion object { + private const val STREAM_BUFFER_SIZE = 8 * 1024 + } +} + +private class HashingRawSource( + private val source: 
Source, + private val hash: DataSourceHash +) : RawSource { + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long { + val start = sink.size + val read = source.readAtMostTo(sink, byteCount) + if (read > 0) { + val copied = Buffer() + sink.copyTo(copied, startIndex = start, endIndex = start + read) + hash.update(copied.readByteArray()) + } + return read + } + + override fun close() { + source.close() + } +} + +private fun ByteArray.toDataSourceSource(): Source { + val buffer = Buffer() + buffer.write(this) + return buffer } diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt index 1b864465..01e4924e 100644 --- a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt @@ -1,6 +1,9 @@ package sk.ainet.data.source import kotlinx.coroutines.test.runTest +import kotlinx.io.Buffer +import kotlinx.io.Source +import kotlinx.io.readByteArray import kotlin.test.Test import kotlin.test.assertContentEquals import kotlin.test.assertEquals @@ -23,6 +26,10 @@ class DefaultDataSourceResolverTest { assertTrue(artifact.cacheHit) assertEquals(0, fetcher.calls) assertContentEquals("local".encodeToByteArray(), artifact.readBytes()) + + val copied = Buffer() + assertEquals(5, artifact.copyTo(copied)) + assertContentEquals("local".encodeToByteArray(), copied.readByteArray()) } @Test @@ -66,8 +73,9 @@ class DefaultDataSourceResolverTest { @Test fun verifiesChecksumsInCommonCore() = runTest { + val store = MemoryDataSourceByteStore() val resolver = DefaultDataSourceResolver( - store = MemoryDataSourceByteStore(), + store = store, fetcher = RecordingFetcher("payload".encodeToByteArray()), checksum = TestChecksum ) @@ -80,6 +88,7 @@ class 
DefaultDataSourceResolverTest { ) ) } + assertEquals(0, store.cacheWrites) } @Test @@ -113,8 +122,8 @@ class DefaultDataSourceResolverTest { private class MemoryDataSourceByteStore( private val localArtifacts: Map = emptyMap() -) : DataSourceByteStore { - private val cacheArtifacts = mutableMapOf() +) : DataSourceArtifactStore { + private val cacheArtifacts = mutableMapOf() var cacheWrites: Int = 0 private set @@ -124,22 +133,29 @@ private class MemoryDataSourceByteStore( } override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? { - return cacheArtifacts[cacheKey]?.storedAt("/cache/$cacheKey") + return cacheArtifacts[cacheKey] } - override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { + override suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long?, + validate: suspend (DataSourceStoredArtifact) -> Unit + ): DataSourceStoredArtifact { + val stored = DataSourceStoredArtifact.inMemoryFrom( + source = source, + localPath = "/cache/$cacheKey", + sizeBytes = sizeBytes + ) + validate(stored) cacheWrites++ - cacheArtifacts[cacheKey] = bytes - return bytes.storedAt("/cache/$cacheKey") + cacheArtifacts[cacheKey] = stored + return stored } private fun ByteArray.storedAt(path: String): DataSourceStoredArtifact { val bytes = copyOf() - return DataSourceStoredArtifact( - localPath = path, - sizeBytes = bytes.size.toLong(), - byteReader = { bytes.copyOf() } - ) + return DataSourceStoredArtifact.inMemory(bytes, localPath = path) } } @@ -152,13 +168,23 @@ private class RecordingFetcher( var lastHeaders: Map = emptyMap() private set - override suspend fun fetch(uri: String, headers: Map): ByteArray { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { calls++ lastHeaders = headers - return bytes.copyOf() + return DataSourceRemoteContent.fromBytes(bytes.copyOf()) } } private object TestChecksum : DataSourceChecksum { - override fun sha256Hex(bytes: ByteArray): 
String = "sha:${bytes.decodeToString()}" + override fun newSha256(): DataSourceHash = TestHash() +} + +private class TestHash : DataSourceHash { + private val text = StringBuilder() + + override fun update(bytes: ByteArray, startIndex: Int, endIndex: Int) { + text.append(bytes.copyOfRange(startIndex, endIndex).decodeToString()) + } + + override fun hex(): String = "sha:$text" } diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt index 0dec564d..2a148bf4 100644 --- a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -1,11 +1,15 @@ package sk.ainet.data.source import io.ktor.client.HttpClient -import io.ktor.client.call.body import io.ktor.client.engine.cio.CIO import io.ktor.client.plugins.HttpTimeout import io.ktor.client.request.get import io.ktor.client.request.header +import io.ktor.client.statement.bodyAsChannel +import io.ktor.http.HttpHeaders +import io.ktor.utils.io.asSource +import kotlinx.io.buffered +import kotlinx.io.files.Path import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import java.io.File @@ -24,10 +28,14 @@ public class KtorRemoteDataSourceFetcher( } } ) : RemoteDataSourceFetcher, AutoCloseable { - override suspend fun fetch(uri: String, headers: Map): ByteArray { - return client.get(uri) { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { + val response = client.get(uri) { headers.forEach { (name, value) -> header(name, value) } - }.body() + } + return DataSourceRemoteContent( + source = response.bodyAsChannel().asSource().buffered(), + sizeBytes = response.headers[HttpHeaders.ContentLength]?.toLongOrNull() + ) } override fun close() { @@ -43,7 +51,7 @@ public class 
JvmDataSourceResolver( fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() ) : DataSourceResolver { private val delegate = DefaultDataSourceResolver( - store = JvmFileDataSourceByteStore(cacheDir), + store = FileSystemDataSourceArtifactStore(Path(cacheDir.absolutePath)), fetcher = fetcher, checksum = JvmSha256DataSourceChecksum, headerProvider = JvmHuggingFaceHeaderProvider @@ -62,44 +70,6 @@ public class JvmDataSourceResolver( } } -internal class JvmFileDataSourceByteStore( - private val cacheDir: File -) : DataSourceByteStore { - override suspend fun readLocal(path: String): DataSourceStoredArtifact? { - val file = File(path) - if (!file.exists()) return null - if (!file.isFile) { - throw DataSourceException("Data source path is not a file: ${file.absolutePath}") - } - return file.toStoredArtifact() - } - - override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? { - val target = File(cacheDir, cacheKey) - return if (target.exists() && target.isFile) target.toStoredArtifact() else null - } - - override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { - cacheDir.mkdirs() - val target = File(cacheDir, cacheKey) - val temp = File(cacheDir, "$cacheKey.tmp") - temp.writeBytes(bytes) - if (!temp.renameTo(target)) { - temp.copyTo(target, overwrite = true) - temp.delete() - } - return target.toStoredArtifact() - } - - private fun File.toStoredArtifact(): DataSourceStoredArtifact { - return DataSourceStoredArtifact( - localPath = absolutePath, - sizeBytes = length(), - byteReader = { readBytes() } - ) - } -} - internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers @@ -113,9 +83,19 @@ internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { } internal object JvmSha256DataSourceChecksum : 
DataSourceChecksum { - override fun sha256Hex(bytes: ByteArray): String { - return MessageDigest.getInstance("SHA-256") - .digest(bytes) + override fun newSha256(): DataSourceHash = JvmSha256DataSourceHash() +} + +private class JvmSha256DataSourceHash : DataSourceHash { + private val digest = MessageDigest.getInstance("SHA-256") + + override fun update(bytes: ByteArray, startIndex: Int, endIndex: Int) { + digest.update(bytes, startIndex, endIndex - startIndex) + } + + override fun hex(): String { + return digest + .digest() .joinToString("") { byte -> "%02x".format(byte) } } } diff --git a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt index 5b44523e..f0102d38 100644 --- a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt @@ -22,7 +22,7 @@ class JvmDataSourceResolverTest { val artifact = resolver.resolve(DataSourceRequest(file.toURI().toString())) assertEquals("sample.txt", artifact.filename) - assertEquals(file.absolutePath, artifact.localPath) + assertEquals(file.canonicalPath, artifact.localPath) assertTrue(artifact.cacheHit) assertContentEquals("hello".encodeToByteArray(), artifact.readBytes()) } finally { @@ -142,9 +142,9 @@ private class FakeFetcher( var calls: Int = 0 private set - override suspend fun fetch(uri: String, headers: Map): ByteArray { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { calls++ - return bytes + return DataSourceRemoteContent.fromBytes(bytes) } } @@ -154,9 +154,9 @@ private class QueueFetcher( var calls: Int = 0 private set - override suspend fun fetch(uri: String, headers: Map): ByteArray { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { val index = 
calls.coerceAtMost(responses.lastIndex) calls++ - return responses[index] + return DataSourceRemoteContent.fromBytes(responses[index]) } } From 35ad83350f47f8e1373f3b325389b9c20d8e4193 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 19:31:58 +0200 Subject: [PATCH 11/18] data: parameterize Hugging Face auth --- .../data-sources-getting-started.adoc | 38 +++++++++-- .../sk/ainet/data/cifar10/CIFAR10Data.kt | 5 +- .../common/DatasetHuggingFaceTokenProvider.kt | 9 +++ .../data/fashionmnist/FashionMNISTData.kt | 5 +- .../kotlin/sk/ainet/data/mnist/MNISTData.kt | 5 +- .../sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt | 7 +- .../data/common/JvmDatasetSourceReader.kt | 13 +++- .../fashionmnist/FashionMNISTLoaderJvm.kt | 7 +- .../sk/ainet/data/mnist/MNISTLoaderJvm.kt | 7 +- .../sk/ainet/data/source/DataSourceModels.kt | 30 ++++++++- .../data/source/DefaultDataSourceResolver.kt | 34 +++++++++- .../source/DefaultDataSourceResolverTest.kt | 67 +++++++++++++++++++ .../data/source/JvmDataSourceResolver.kt | 30 ++++++--- .../data/source/JvmDataSourceResolverTest.kt | 49 ++++++++++++++ 14 files changed, 277 insertions(+), 29 deletions(-) create mode 100644 skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt diff --git a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc index 34eb72a7..4557e875 100644 --- a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc +++ b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc @@ -41,16 +41,20 @@ dependencies { `JvmDataSourceResolver` materializes remote artifacts into a cache and returns a `DataSourceArtifact` that opens a `kotlinx.io.Source`. Public Hugging Face -files do not need credentials. 
Private files can use an `Authorization` header, -or the JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the -environment when the URI provider is Hugging Face. +files do not need credentials. Private files should pass an explicit +`DataSourceAuthToken` on the request or resolver. Existing `Authorization` +headers still take precedence. On JVM, the resolver can also read `HF_TOKEN` / +`HUGGING_FACE_HUB_TOKEN` from the environment as an opt-in convenience fallback. [source,kotlin] ---- +import sk.ainet.data.source.DataSourceAuthToken import sk.ainet.data.source.DataSourceRequest import sk.ainet.data.source.JvmDataSourceResolver -val resolver = JvmDataSourceResolver() +val resolver = JvmDataSourceResolver( + huggingFaceToken = DataSourceAuthToken.from("hf_...") +) val artifact = resolver.resolve( DataSourceRequest( uri = "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" @@ -71,6 +75,28 @@ try { val bytes = artifact.readBytes() ---- +For per-request credentials, pass the token directly on `DataSourceRequest`. +This is useful when one resolver works with more than one private repository: + +[source,kotlin] +---- +val privateArtifact = resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/your-org/private-dataset@main/data/train.bin", + huggingFaceToken = DataSourceAuthToken.from("hf_...") + ) +) +---- + +To opt into JVM environment fallback: + +[source,kotlin] +---- +val resolver = JvmDataSourceResolver( + useEnvironmentHuggingFaceToken = true +) +---- + === Use sources with built-in loaders MNIST and Fashion-MNIST expose per-file URI overrides. CIFAR-10 exposes an @@ -82,10 +108,12 @@ locations, so existing code keeps working. import sk.ainet.data.mnist.MNIST import sk.ainet.data.mnist.MNISTLoaderConfig +val token = "hf_..." 
val train = MNIST.loadTrain( MNISTLoaderConfig( trainImagesUri = "file:///datasets/mnist/train-images-idx3-ubyte", - trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz" + trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz", + huggingFaceTokenProvider = { token } ) ) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt index 9344e150..1c63b86f 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -145,7 +146,9 @@ public data class CIFAR10Dataset( public data class CIFAR10LoaderConfig( val cacheDir: String = "cifar10-data", val useCache: Boolean = true, - val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL + val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL, + val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? 
= null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt new file mode 100644 index 00000000..1efeacd1 --- /dev/null +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt @@ -0,0 +1,9 @@ +package sk.ainet.data.common + +/** + * Supplies a Hugging Face token for built-in dataset loaders when their source + * URIs point at private Hugging Face artifacts. + */ +public fun interface DatasetHuggingFaceTokenProvider { + public fun token(): String? +} diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt index bdc0e99e..c9286506 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -150,7 +151,9 @@ public data class FashionMNISTLoaderConfig( val trainImagesUri: String = FashionMNISTConstants.TRAIN_IMAGES_URL, val trainLabelsUri: String = FashionMNISTConstants.TRAIN_LABELS_URL, val testImagesUri: String = FashionMNISTConstants.TEST_IMAGES_URL, - val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL + val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL, + val huggingFaceTokenProvider: 
DatasetHuggingFaceTokenProvider? = null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt index c4bfa76a..ae6881cf 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -128,7 +129,9 @@ public data class MNISTLoaderConfig( val trainImagesUri: String = MNISTConstants.TRAIN_IMAGES_URL, val trainLabelsUri: String = MNISTConstants.TRAIN_LABELS_URL, val testImagesUri: String = MNISTConstants.TEST_IMAGES_URL, - val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL + val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL, + val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt index 10fc06a3..5ecf29ee 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt @@ -15,7 +15,12 @@ import java.io.FileOutputStream * @property config The configuration for the CIFAR-10 loader. 
*/ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) { - private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** * Downloads the CIFAR-10 archive and extracts the specified batch file. diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt index 29ce6bbf..785910fd 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt @@ -1,6 +1,7 @@ package sk.ainet.data.common import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceAuthToken import sk.ainet.data.source.DataSourceRequest import sk.ainet.data.source.JvmDataSourceResolver import java.io.ByteArrayInputStream @@ -9,16 +10,22 @@ import java.util.zip.GZIPInputStream internal class JvmDatasetSourceReader( cacheDir: String, - useCache: Boolean + useCache: Boolean, + private val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? 
= null, + useEnvironmentHuggingFaceToken: Boolean = false ) { - private val resolver = JvmDataSourceResolver(File(cacheDir, "sources")) + private val resolver = JvmDataSourceResolver( + cacheDir = File(cacheDir, "sources"), + useEnvironmentHuggingFaceToken = useEnvironmentHuggingFaceToken + ) private val cachePolicy = if (useCache) CachePolicy.Use else CachePolicy.Refresh suspend fun read(uri: String): ByteArray { val artifact = resolver.resolve( DataSourceRequest( uri = uri, - cachePolicy = cachePolicy + cachePolicy = cachePolicy, + huggingFaceToken = DataSourceAuthToken.fromOrNull(huggingFaceTokenProvider?.token()) ) ) return artifact.readBytes() diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt index 9b100db9..27c598e6 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt @@ -8,7 +8,12 @@ import sk.ainet.data.common.JvmDatasetSourceReader * @property config The configuration for the Fashion-MNIST loader. */ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) { - private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** * Resolves, caches, and decompresses a file when needed. 
diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt index e5466b4f..0457d66c 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt @@ -8,7 +8,12 @@ import sk.ainet.data.common.JvmDatasetSourceReader * @property config The configuration for the MNIST loader. */ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) { - private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** * Resolves, caches, and decompresses a file when needed. diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt index 0e34245d..12077c0b 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt @@ -50,6 +50,33 @@ public data class HuggingFaceLocation( public val path: String? ) +/** + * Authentication token for provider-specific data source requests. + * + * The raw value is intentionally hidden from [toString] output so tokens are + * not leaked when requests or configs are logged. 
+ */ +public class DataSourceAuthToken private constructor( + private val value: String +) { + override fun toString(): String = "DataSourceAuthToken(***)" + + internal fun authorizationHeaderValue(): String = "Bearer $value" + + public companion object { + public fun from(value: String): DataSourceAuthToken { + val normalized = value.trim() + require(normalized.isNotEmpty()) { "Data source auth token cannot be blank" } + return DataSourceAuthToken(normalized) + } + + public fun fromOrNull(value: String?): DataSourceAuthToken? { + val normalized = value?.trim()?.takeIf { it.isNotEmpty() } ?: return null + return DataSourceAuthToken(normalized) + } + } +} + /** * A normalized, provider-aware source URI. */ @@ -70,7 +97,8 @@ public data class DataSourceRequest( public val uri: String, public val cachePolicy: CachePolicy = CachePolicy.Use, public val expectedSha256: String? = null, - public val headers: Map = emptyMap() + public val headers: Map = emptyMap(), + public val huggingFaceToken: DataSourceAuthToken? = null ) /** diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt index b7ee217f..484de79d 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt @@ -42,6 +42,32 @@ public fun interface DataSourceHeaderProvider { public fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map } +/** + * Supplies a Hugging Face token when a request does not carry one directly. + */ +public fun interface HuggingFaceTokenProvider { + public fun token(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): DataSourceAuthToken? 
+} + +/** + * Adds Hugging Face bearer auth from explicit request or resolver-level token + * configuration while leaving generic HTTP requests unchanged. + */ +public class HuggingFaceTokenHeaderProvider( + private val tokenProvider: HuggingFaceTokenProvider = HuggingFaceTokenProvider { _, _ -> null } +) : DataSourceHeaderProvider { + override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { + if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers + if (request.headers.hasAuthorizationHeader()) return request.headers + val token = request.huggingFaceToken ?: tokenProvider.token(request, parsedUri) ?: return request.headers + return request.headers + (AUTHORIZATION_HEADER to token.authorizationHeaderValue()) + } + + private companion object { + private const val AUTHORIZATION_HEADER = "Authorization" + } +} + /** * Computes checksums for integrity verification without tying resolver policy * to a concrete platform crypto API. @@ -204,9 +230,7 @@ public class DefaultDataSourceResolver( private val store: DataSourceArtifactStore, private val fetcher: RemoteDataSourceFetcher, private val checksum: DataSourceChecksum, - private val headerProvider: DataSourceHeaderProvider = DataSourceHeaderProvider { request, _ -> - request.headers - } + private val headerProvider: DataSourceHeaderProvider = HuggingFaceTokenHeaderProvider() ) : DataSourceResolver { override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact { val parsed = DataSourceUriParser.parse(request.uri) @@ -326,6 +350,10 @@ public class DefaultDataSourceResolver( } } +private fun Map.hasAuthorizationHeader(): Boolean { + return keys.any { it.equals("Authorization", ignoreCase = true) } +} + private class HashingRawSource( private val source: Source, private val hash: DataSourceHash diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt 
b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt index 01e4924e..e3803588 100644 --- a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt @@ -118,6 +118,73 @@ class DefaultDataSourceResolverTest { fetcher.lastHeaders ) } + + @Test + fun addsHuggingFaceTokenFromRequest() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum + ) + val token = DataSourceAuthToken.from("hf_request") + + resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/org/repo@main/file.bin", + headers = mapOf("Accept" to "application/octet-stream"), + huggingFaceToken = token + ) + ) + + assertEquals( + mapOf( + "Accept" to "application/octet-stream", + "Authorization" to "Bearer hf_request" + ), + fetcher.lastHeaders + ) + assertEquals("DataSourceAuthToken(***)", token.toString()) + } + + @Test + fun keepsExistingAuthorizationHeaderOverHuggingFaceToken() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://org/repo@main/file.bin", + headers = mapOf("authorization" to "Bearer explicit"), + huggingFaceToken = DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(mapOf("authorization" to "Bearer explicit"), fetcher.lastHeaders) + } + + @Test + fun doesNotAddHuggingFaceTokenToGenericHttp() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum 
+ ) + + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + huggingFaceToken = DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(emptyMap(), fetcher.lastHeaders) + } } private class MemoryDataSourceByteStore( diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt index 2a148bf4..fd7c9a88 100644 --- a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -48,13 +48,20 @@ public class KtorRemoteDataSourceFetcher( */ public class JvmDataSourceResolver( cacheDir: File = defaultCacheDir(), - fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() + fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher(), + huggingFaceToken: DataSourceAuthToken? 
= null, + useEnvironmentHuggingFaceToken: Boolean = false ) : DataSourceResolver { private val delegate = DefaultDataSourceResolver( store = FileSystemDataSourceArtifactStore(Path(cacheDir.absolutePath)), fetcher = fetcher, checksum = JvmSha256DataSourceChecksum, - headerProvider = JvmHuggingFaceHeaderProvider + headerProvider = HuggingFaceTokenHeaderProvider( + JvmHuggingFaceTokenProvider( + configuredToken = huggingFaceToken, + useEnvironmentToken = useEnvironmentHuggingFaceToken + ) + ) ) override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { @@ -70,15 +77,16 @@ public class JvmDataSourceResolver( } } -internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { - override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { - if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers - if (request.headers.keys.any { it.equals("Authorization", ignoreCase = true) }) return request.headers - val token = System.getenv("HF_TOKEN") - ?.takeIf { it.isNotBlank() } - ?: System.getenv("HUGGING_FACE_HUB_TOKEN")?.takeIf { it.isNotBlank() } - ?: return request.headers - return request.headers + ("Authorization" to "Bearer $token") +internal class JvmHuggingFaceTokenProvider( + private val configuredToken: DataSourceAuthToken?, + private val useEnvironmentToken: Boolean +) : HuggingFaceTokenProvider { + override fun token(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): DataSourceAuthToken? 
{ + if (parsedUri.provider != DataSourceProvider.HuggingFace) return null + configuredToken?.let { return it } + if (!useEnvironmentToken) return null + return DataSourceAuthToken.fromOrNull(System.getenv("HF_TOKEN")) + ?: DataSourceAuthToken.fromOrNull(System.getenv("HUGGING_FACE_HUB_TOKEN")) } } diff --git a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt index f0102d38..4f997cd4 100644 --- a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt @@ -134,6 +134,51 @@ class JvmDataSourceResolverTest { root.deleteRecursively() } } + + @Test + fun sendsConfiguredHuggingFaceToken() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("payload".encodeToByteArray()) + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = fetcher, + huggingFaceToken = DataSourceAuthToken.from("hf_configured"), + useEnvironmentHuggingFaceToken = false + ) + + resolver.resolve(DataSourceRequest("hf://org/repo@main/file.bin")) + + assertEquals(mapOf("Authorization" to "Bearer hf_configured"), fetcher.lastHeaders) + } finally { + root.deleteRecursively() + } + } + + @Test + fun requestHuggingFaceTokenOverridesConfiguredToken() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("payload".encodeToByteArray()) + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = fetcher, + huggingFaceToken = DataSourceAuthToken.from("hf_configured"), + useEnvironmentHuggingFaceToken = false + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://org/repo@main/file.bin", + huggingFaceToken = 
DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(mapOf("Authorization" to "Bearer hf_request"), fetcher.lastHeaders) + } finally { + root.deleteRecursively() + } + } } private class FakeFetcher( @@ -142,8 +187,12 @@ private class FakeFetcher( var calls: Int = 0 private set + var lastHeaders: Map = emptyMap() + private set + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { calls++ + lastHeaders = headers return DataSourceRemoteContent.fromBytes(bytes) } } From de57f8735a929fee9ae0feb5785ce5d94477eabb Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:33:03 +0200 Subject: [PATCH 12/18] data: support indexed simple batches --- .../sk/ainet/data/cifar10/CIFAR10Data.kt | 26 ++++++++++++++----- .../data/fashionmnist/FashionMNISTData.kt | 26 ++++++++++++++----- .../kotlin/sk/ainet/data/mnist/MNISTData.kt | 26 ++++++++++++++----- .../io/data/cifar10/CIFAR10LoaderTest.kt | 14 ++++++++++ .../fashionmnist/FashionMNISTLoaderTest.kt | 14 ++++++++++ .../sk/ainet/io/data/mnist/MNISTLoaderTest.kt | 14 ++++++++++ 6 files changed, 99 insertions(+), 21 deletions(-) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt index 1c63b86f..134722df 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt @@ -102,10 +102,23 @@ public data class CIFAR10Dataset( */ override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { val actualLen = min(batchLength, xSize - batchStart) - val batchImages = images.subList(batchStart, batchStart + actualLen) + val batchIndices = IntArray(actualLen) { offset -> batchStart + offset } + return createDataBatchForIndices(batchIndices) + } + + /** + * Creates a DataBatch from arbitrary sample 
indices. This keeps shuffled and + * filtered dataset views compatible with tensor batching. + */ + override fun createIndexedDataBatch(indices: IntArray): DataBatch = + createDataBatchForIndices(indices) + + @Suppress("UNCHECKED_CAST") + private fun createDataBatchForIndices(indices: IntArray): DataBatch { + val batchImages = indices.map { images[it] } // Concatenate raw image bytes (no normalization) for memory efficiency - val xData = ByteArray(actualLen * CIFAR10Constants.IMAGE_BYTES) + val xData = ByteArray(indices.size * CIFAR10Constants.IMAGE_BYTES) var offset = 0 for (sample in batchImages) { val bytes = sample.image @@ -114,19 +127,18 @@ public data class CIFAR10Dataset( } // Shape as [batch, 3, 32, 32] (channel-first) - val xShape = Shape(actualLen, CIFAR10Constants.NUM_CHANNELS, CIFAR10Constants.IMAGE_SIZE, CIFAR10Constants.IMAGE_SIZE) + val xShape = Shape(indices.size, CIFAR10Constants.NUM_CHANNELS, CIFAR10Constants.IMAGE_SIZE, CIFAR10Constants.IMAGE_SIZE) val xTensor: Tensor = executionContext.fromByteArray(xShape, Int8::class, xData) // Labels as bytes (memory-efficient) - val yData = ByteArray(actualLen) { idx -> batchImages[idx].label } - val yShape = Shape(actualLen) + val yData = ByteArray(indices.size) { idx -> batchImages[idx].label } + val yShape = Shape(indices.size) val yTensor: Tensor = executionContext.fromByteArray(yShape, Int8::class, yData) // DataBatch expects array of input tensors; we provide single input val xArray: Array> = arrayOf(xTensor) - @Suppress("UNCHECKED_CAST") - return DataBatch(xArray as Array>, yTensor as Tensor) + return DataBatch(xArray as Array>, yTensor as Tensor, indices = indices.copyOf()) } /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt index c9286506..7ee50a34 100644 --- 
a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt @@ -104,10 +104,23 @@ public data class FashionMNISTDataset( */ override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { val actualLen = min(batchLength, xSize - batchStart) - val batchImages = images.subList(batchStart, batchStart + actualLen) + val batchIndices = IntArray(actualLen) { offset -> batchStart + offset } + return createDataBatchForIndices(batchIndices) + } + + /** + * Creates a DataBatch from arbitrary sample indices. This keeps shuffled and + * filtered dataset views compatible with tensor batching. + */ + override fun createIndexedDataBatch(indices: IntArray): DataBatch = + createDataBatchForIndices(indices) + + @Suppress("UNCHECKED_CAST") + private fun createDataBatchForIndices(indices: IntArray): DataBatch { + val batchImages = indices.map { images[it] } // Concatenate raw image bytes (no normalization) for memory efficiency - val xData = ByteArray(actualLen * FashionMNISTConstants.IMAGE_PIXELS) + val xData = ByteArray(indices.size * FashionMNISTConstants.IMAGE_PIXELS) var offset = 0 for (sample in batchImages) { val bytes = sample.image @@ -116,19 +129,18 @@ public data class FashionMNISTDataset( } // Shape as [batch, 1, 28, 28] - val xShape = Shape(actualLen, 1, FashionMNISTConstants.IMAGE_SIZE, FashionMNISTConstants.IMAGE_SIZE) + val xShape = Shape(indices.size, 1, FashionMNISTConstants.IMAGE_SIZE, FashionMNISTConstants.IMAGE_SIZE) val xTensor: Tensor = executionContext.fromByteArray(xShape, Int8::class, xData) // Labels as bytes (memory-efficient) - val yData = ByteArray(actualLen) { idx -> batchImages[idx].label } - val yShape = Shape(actualLen) + val yData = ByteArray(indices.size) { idx -> batchImages[idx].label } + val yShape = Shape(indices.size) val yTensor: Tensor = executionContext.fromByteArray(yShape, 
Int8::class, yData) // DataBatch expects array of input tensors; we provide single input val xArray: Array> = arrayOf(xTensor) - @Suppress("UNCHECKED_CAST") - return DataBatch(xArray as Array>, yTensor as Tensor) + return DataBatch(xArray as Array>, yTensor as Tensor, indices = indices.copyOf()) } /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt index ae6881cf..a4875219 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt @@ -82,10 +82,23 @@ public data class MNISTDataset( */ override fun createDataBatch(batchStart: Int, batchLength: Int): DataBatch { val actualLen = min(batchLength, xSize - batchStart) - val batchImages = images.subList(batchStart, batchStart + actualLen) + val batchIndices = IntArray(actualLen) { offset -> batchStart + offset } + return createDataBatchForIndices(batchIndices) + } + + /** + * Creates a DataBatch from arbitrary sample indices. This keeps shuffled and + * filtered dataset views compatible with tensor batching. 
+ */ + override fun createIndexedDataBatch(indices: IntArray): DataBatch = + createDataBatchForIndices(indices) + + @Suppress("UNCHECKED_CAST") + private fun createDataBatchForIndices(indices: IntArray): DataBatch { + val batchImages = indices.map { images[it] } // Concatenate raw image bytes (no normalization) for memory efficiency - val xData = ByteArray(actualLen * MNISTConstants.IMAGE_PIXELS) + val xData = ByteArray(indices.size * MNISTConstants.IMAGE_PIXELS) var offset = 0 for (sample in batchImages) { val bytes = sample.image @@ -94,19 +107,18 @@ public data class MNISTDataset( } // Shape as [batch, 1, 28, 28] - val xShape = Shape(actualLen, 1, MNISTConstants.IMAGE_SIZE, MNISTConstants.IMAGE_SIZE) + val xShape = Shape(indices.size, 1, MNISTConstants.IMAGE_SIZE, MNISTConstants.IMAGE_SIZE) val xTensor: Tensor = executionContext.fromByteArray(xShape, Int8::class, xData) // Labels as bytes (memory-efficient). Keep as Int8 to satisfy DataBatch single dtype requirement - val yData = ByteArray(actualLen) { idx -> batchImages[idx].label } - val yShape = Shape(actualLen) + val yData = ByteArray(indices.size) { idx -> batchImages[idx].label } + val yShape = Shape(indices.size) val yTensor: Tensor = executionContext.fromByteArray(yShape, Int8::class, yData) // DataBatch expects array of input tensors; we provide single input val xArray: Array> = arrayOf(xTensor) - @Suppress("UNCHECKED_CAST") - return DataBatch(xArray as Array>, yTensor as Tensor) + return DataBatch(xArray as Array>, yTensor as Tensor, indices = indices.copyOf()) } /** diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt index 541516b3..9d4d363f 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt +++ 
b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.cifar10.CIFAR10Image import sk.ainet.data.cifar10.CIFAR10LoaderConfig import sk.ainet.data.cifar10.CIFAR10LoaderCommon import sk.ainet.data.cifar10.createCIFAR10Loader +import sk.ainet.lang.types.Int8 import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -56,6 +57,19 @@ class CIFAR10LoaderTest { assertEquals(dataset.images[1], subset.images[1]) } + @Test + fun testShuffledDatasetViewCanCreateBatch() = runBlocking { + val dataset = createFakeLoader().loadTrainingData() + val shuffled = dataset.shuffle(seed = 123) + + val batch = shuffled.batchIterator(4).next() + + assertEquals(4, batch.batchSize) + assertEquals(4, batch.indices.size) + assertEquals(4, batch.x[0].shape[0]) + assertEquals(4, batch.y.shape[0]) + } + @Test fun testLoaderConfiguration() { val config = CIFAR10LoaderConfig( diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt index 30135890..2d11cb59 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.fashionmnist.FashionMNISTImage import sk.ainet.data.fashionmnist.FashionMNISTLoaderConfig import sk.ainet.data.fashionmnist.FashionMNISTLoaderCommon import sk.ainet.data.fashionmnist.createFashionMNISTLoader +import sk.ainet.lang.types.Int8 import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -56,6 +57,19 @@ class FashionMNISTLoaderTest { assertEquals(dataset.images[1], subset.images[1]) } + @Test + fun testShuffledDatasetViewCanCreateBatch() = runBlocking { 
+ val dataset = createFakeLoader().loadTrainingData() + val shuffled = dataset.shuffle(seed = 123) + + val batch = shuffled.batchIterator(2).next() + + assertEquals(2, batch.batchSize) + assertEquals(2, batch.indices.size) + assertEquals(2, batch.x[0].shape[0]) + assertEquals(2, batch.y.shape[0]) + } + @Test fun testLoaderConfiguration() { val config = FashionMNISTLoaderConfig( diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt index 94796645..1de5fc26 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.mnist.MNISTImage import sk.ainet.data.mnist.MNISTLoaderConfig import sk.ainet.data.mnist.MNISTLoaderFactory import sk.ainet.data.mnist.MNISTLoaderCommon +import sk.ainet.lang.types.Int8 import java.nio.file.Files import kotlin.test.Test import kotlin.test.assertEquals @@ -57,6 +58,19 @@ class MNISTLoaderTest { assertEquals(dataset.images[1], subset.images[1]) } + @Test + fun testShuffledDatasetViewCanCreateBatch() = runBlocking { + val dataset = createFakeLoader().loadTrainingData() + val shuffled = dataset.shuffle(seed = 123) + + val batch = shuffled.batchIterator(2).next() + + assertEquals(2, batch.batchSize) + assertEquals(2, batch.indices.size) + assertEquals(2, batch.x[0].shape[0]) + assertEquals(2, batch.y.shape[0]) + } + @Test fun testLoaderConfiguration() { val config = MNISTLoaderConfig( From 7074a6502393e71739acbced67cac54bbeb53be3 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:36:39 +0200 Subject: [PATCH 13/18] data: report unsupported loader targets --- ...DatasetLoaderUnsupportedTargetException.kt | 20 +++++++++++++++++++ .../sk/ainet/data/cifar10/CIFAR10LoaderIos.kt | 7 ++++++- 
.../fashionmnist/FashionMNISTLoaderIos.kt | 12 ++++++++--- .../sk/ainet/data/mnist/MNISTLoaderIos.kt | 14 +++++++++---- .../sk/ainet/data/cifar10/CIFAR10LoaderJs.kt | 7 ++++++- .../data/fashionmnist/FashionMNISTLoaderJs.kt | 12 ++++++++++- .../sk/ainet/data/mnist/MNISTLoaderJs.kt | 12 ++++++++++- .../ainet/data/cifar10/CIFAR10LoaderLinux.kt | 7 ++++++- .../fashionmnist/FashionMNISTLoaderLinux.kt | 7 ++++++- .../sk/ainet/data/mnist/MNISTLoaderLinux.kt | 7 ++++++- .../ainet/data/cifar10/CIFAR10LoaderMacos.kt | 7 ++++++- .../fashionmnist/FashionMNISTLoaderMacos.kt | 7 ++++++- .../sk/ainet/data/mnist/MNISTLoaderMacos.kt | 7 ++++++- .../ainet/data/cifar10/CIFAR10LoaderWasmJs.kt | 7 ++++++- .../fashionmnist/FashionMNISTLoaderWasmJs.kt | 12 ++++++++--- .../sk/ainet/data/mnist/MNISTLoaderWasmJs.kt | 12 ++++++++--- 16 files changed, 133 insertions(+), 24 deletions(-) create mode 100644 skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetLoaderUnsupportedTargetException.kt diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetLoaderUnsupportedTargetException.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetLoaderUnsupportedTargetException.kt new file mode 100644 index 00000000..70137a49 --- /dev/null +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetLoaderUnsupportedTargetException.kt @@ -0,0 +1,20 @@ +package sk.ainet.data.common + +/** + * Raised when a built-in dataset loader is compiled for a target where the + * required transport, archive, or decompression primitive is not implemented. + */ +public class DatasetLoaderUnsupportedTargetException( + public val dataset: String, + public val target: String, + reason: String +) : UnsupportedOperationException("$dataset loader is not supported on $target: $reason") + +/** Throws a typed unsupported-target exception for built-in dataset loaders. 
*/ +public fun unsupportedDatasetLoader(dataset: String, target: String, reason: String): Nothing { + throw DatasetLoaderUnsupportedTargetException(dataset, target, reason) +} + +/** Returns true when this byte array starts with the gzip magic header. */ +public fun ByteArray.hasGzipHeader(): Boolean = + size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() diff --git a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt index 88384f31..7d0a24a5 100644 --- a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt +++ b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderIos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * iOS implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderIos(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderIos is not fully implemented yet. 
Tar.gz extraction requires native implementation.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "ios", + reason = "tar.gz extraction is not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt index ea97291e..98d96341 100644 --- a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt +++ b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderIos.kt @@ -7,6 +7,8 @@ import io.ktor.client.request.get import io.ktor.client.statement.HttpResponse import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * iOS implementation of the Fashion-MNIST loader. @@ -29,9 +31,13 @@ public class FashionMNISTLoaderIos(config: FashionMNISTLoaderConfig) : FashionMN println("Downloading Fashion-MNIST file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a native gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for iOS clients - println("iOS implementation does not support gzip decompression in this example. 
Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "ios", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } diff --git a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt index d39fb55b..c422c9bd 100644 --- a/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt +++ b/skainet-data/skainet-data-simple/src/iosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderIos.kt @@ -7,6 +7,8 @@ import io.ktor.client.request.get import io.ktor.client.statement.HttpResponse import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * iOS implementation of the MNIST loader. @@ -29,9 +31,13 @@ public class MNISTLoaderIos(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi println("Downloading file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a native gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for iOS clients - println("iOS implementation does not support gzip decompression in this example. 
Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "MNIST", + target = "ios", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } @@ -85,4 +91,4 @@ public class MNISTLoaderIos(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi return MNISTLoaderIos(config) } } -} \ No newline at end of file +} diff --git a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt index efe9ca1c..3e1fe62f 100644 --- a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt +++ b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJs.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * JS (browser) implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderJs(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon( override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderJs is not fully implemented yet. 
Tar.gz extraction requires JS library support.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "js", + reason = "tar.gz extraction is not implemented for this browser target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt index d461e02f..114dc37d 100644 --- a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt +++ b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJs.kt @@ -7,9 +7,13 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import kotlin.js.ExperimentalWasmJsInterop import kotlin.js.Promise import kotlinx.coroutines.await +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader +@OptIn(ExperimentalWasmJsInterop::class) @JsFun( """ async function(input) { @@ -45,7 +49,13 @@ public class FashionMNISTLoaderJs(config: FashionMNISTLoaderConfig) : FashionMNI if (decompressed != null) { decompressed } else { - println("[FashionMNIST][JS] DecompressionStream not available. 
Returning raw data (likely gzip) which will fail to parse.") + if (gzData.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "js", + reason = "browser DecompressionStream is unavailable; provide an uncompressed IDX URI" + ) + } gzData } } diff --git a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt index 797994cc..3c9051b1 100644 --- a/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt +++ b/skainet-data/skainet-data-simple/src/jsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJs.kt @@ -7,9 +7,13 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import kotlin.js.ExperimentalWasmJsInterop import kotlin.js.Promise import kotlinx.coroutines.await +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader +@OptIn(ExperimentalWasmJsInterop::class) @JsFun( """ async function(input) { @@ -45,7 +49,13 @@ public class MNISTLoaderJs(config: MNISTLoaderConfig) : MNISTLoaderCommon(config if (decompressed != null) { decompressed } else { - println("[MNIST][JS] DecompressionStream not available. 
Returning raw data (likely gzip) which will fail to parse.") + if (gzData.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "MNIST", + target = "js", + reason = "browser DecompressionStream is unavailable; provide an uncompressed IDX URI" + ) + } gzData } } diff --git a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt index eb363198..3f06c922 100644 --- a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt +++ b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderLinux.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * Linux implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderLinux(config: CIFAR10LoaderConfig) : CIFAR10LoaderComm override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderLinux is not implemented yet.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "linux", + reason = "tar.gz extraction is not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt index 984cad46..ff1cf0f6 100644 --- a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt +++ b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderLinux.kt @@ -2,6 +2,7 @@ package sk.ainet.data.fashionmnist import kotlinx.coroutines.Dispatchers import 
kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * Linux implementation of the Fashion-MNIST loader. @@ -14,7 +15,11 @@ public class FashionMNISTLoaderLinux(config: FashionMNISTLoaderConfig) : Fashion */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("FashionMNISTLoaderLinux.downloadAndCacheFile is not implemented yet.") + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "linux", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt index 11edb857..046219c3 100644 --- a/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt +++ b/skainet-data/skainet-data-simple/src/linuxMain/kotlin/sk/ainet/data/mnist/MNISTLoaderLinux.kt @@ -2,6 +2,7 @@ package sk.ainet.data.mnist import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * Linux implementation of the MNIST loader. 
@@ -14,7 +15,11 @@ public class MNISTLoaderLinux(config: MNISTLoaderConfig) : MNISTLoaderCommon(con */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("MNISTLoaderLinux.downloadAndCacheFile is not implemented yet.") + unsupportedDatasetLoader( + dataset = "MNIST", + target = "linux", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt index 3d1eab89..45839246 100644 --- a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt +++ b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderMacos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * macOS implementation of the CIFAR-10 loader. 
@@ -13,7 +14,11 @@ public class CIFAR10LoaderMacos(config: CIFAR10LoaderConfig) : CIFAR10LoaderComm override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderMacos is not implemented yet.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "macos", + reason = "tar.gz extraction is not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt index e57d5809..5e6938d1 100644 --- a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt +++ b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderMacos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.fashionmnist import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * macOS implementation of the Fashion-MNIST loader. @@ -14,7 +15,11 @@ public class FashionMNISTLoaderMacos(config: FashionMNISTLoaderConfig) : Fashion */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("FashionMNISTLoaderMacos.downloadAndCacheFile is not implemented yet. 
Provide an appleMain implementation or avoid macOS usage for now.") + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "macos", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt index ca976d5f..acfc80ac 100644 --- a/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt +++ b/skainet-data/skainet-data-simple/src/macosMain/kotlin/sk/ainet/data/mnist/MNISTLoaderMacos.kt @@ -2,6 +2,7 @@ package sk.ainet.data.mnist import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * macOS implementation of the MNIST loader. @@ -14,7 +15,11 @@ public class MNISTLoaderMacos(config: MNISTLoaderConfig) : MNISTLoaderCommon(con */ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.Default) { - error("MNISTLoaderMacos.downloadAndCacheFile is not implemented yet. 
Provide an appleMain implementation or avoid macOS usage for now.") + unsupportedDatasetLoader( + dataset = "MNIST", + target = "macos", + reason = "gzip decompression and cache materialization are not implemented for this native target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt index 85b4648f..2ac5a0f7 100644 --- a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt +++ b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderWasmJs.kt @@ -2,6 +2,7 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.unsupportedDatasetLoader /** * WASM JS implementation of the CIFAR-10 loader. @@ -13,7 +14,11 @@ public class CIFAR10LoaderWasmJs(config: CIFAR10LoaderConfig) : CIFAR10LoaderCom override suspend fun downloadAndExtractBatch(batchFilename: String): ByteArray = withContext(Dispatchers.Default) { - error("CIFAR10LoaderWasmJs is not implemented yet.") + unsupportedDatasetLoader( + dataset = "CIFAR-10", + target = "wasmJs", + reason = "tar.gz extraction is not implemented for this browser target" + ) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt index beab29b4..3af47dcc 100644 --- a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt +++ b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderWasmJs.kt @@ -7,6 +7,8 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body 
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * WASM JS implementation of the Fashion-MNIST loader. @@ -29,9 +31,13 @@ public class FashionMNISTLoaderWasmJs(config: FashionMNISTLoaderConfig) : Fashio println("Downloading Fashion-MNIST file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a JS gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for WASM clients - println("WASM JS implementation does not support gzip decompression. Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "Fashion-MNIST", + target = "wasmJs", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } diff --git a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt index 3886bdff..78b6bbe7 100644 --- a/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt +++ b/skainet-data/skainet-data-simple/src/wasmJsMain/kotlin/sk/ainet/data/mnist/MNISTLoaderWasmJs.kt @@ -7,6 +7,8 @@ import io.ktor.client.statement.HttpResponse import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.hasGzipHeader +import sk.ainet.data.common.unsupportedDatasetLoader /** * WASM JS implementation of the MNIST loader. 
@@ -29,9 +31,13 @@ public class MNISTLoaderWasmJs(config: MNISTLoaderConfig) : MNISTLoaderCommon(co println("Downloading file: $url") val data = downloadFile(url) - // Note: In a real implementation, we would use a JS gzip library to decompress the data - // For this example, we're assuming the server provides uncompressed data for WASM clients - println("WASM JS implementation does not support gzip decompression. Assuming data is already decompressed.") + if (data.hasGzipHeader()) { + unsupportedDatasetLoader( + dataset = "MNIST", + target = "wasmJs", + reason = "gzip decompression is not implemented; provide an uncompressed IDX URI" + ) + } return@withContext data } From 34ff3fbcc66c4a9d165ab339534c08fd04a7570a Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:39:59 +0200 Subject: [PATCH 14/18] data: add raw format parsers --- .../skainet-data-source/build.gradle.kts | 1 + .../sk/ainet/data/source/DataFormatParser.kt | 169 ++++++++++++++++++ .../ainet/data/source/DataFormatParserTest.kt | 90 ++++++++++ 3 files changed, 260 insertions(+) create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt create mode 100644 skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts index 7d7c3065..bcd69b4e 100644 --- a/skainet-data/skainet-data-source/build.gradle.kts +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -19,6 +19,7 @@ kotlin { commonMain.dependencies { implementation(libs.kotlinx.coroutines) implementation(libs.kotlinx.io.core) + implementation(libs.kotlinx.serialization.json) } commonTest.dependencies { diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt new file mode 
100644 index 00000000..5a36877e --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt @@ -0,0 +1,169 @@ +package sk.ainet.data.source + +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive + +/** Supported built-in raw dataset formats. */ +public enum class DataFormat(public val extensions: Set) { + CSV(setOf("csv")), + TSV(setOf("tsv")), + JSON_LINES(setOf("jsonl", "ndjson")) +} + +/** A simple schema inferred from raw stringly parsed data. */ +public data class DataSchema( + public val columns: List +) + +/** One parsed row from a raw tabular or JSON-lines dataset. */ +public data class RawDataRow( + public val values: Map +) + +/** Parsed raw dataset plus schema and lightweight provenance metadata. */ +public data class RawDataset( + public val rows: List, + public val schema: DataSchema, + public val metadata: Map = emptyMap() +) + +/** Parser contract for converting source text into a [RawDataset]. */ +public interface DataFormatParser { + public val format: DataFormat + + public fun parse(text: String): RawDataset +} + +/** Registry for built-in and user-provided data format parsers. 
*/ +public class DataFormatParserRegistry( + parsers: Iterable = defaultDataFormatParsers() +) { + private val parsersByFormat: MutableMap = + parsers.associateBy { it.format }.toMutableMap() + + public fun register(parser: DataFormatParser) { + parsersByFormat[parser.format] = parser + } + + public fun parserFor(format: DataFormat): DataFormatParser = + parsersByFormat[format] ?: throw DataSourceException("No parser registered for $format") + + public fun parse(format: DataFormat, text: String): RawDataset = + parserFor(format).parse(text) + + public companion object { + public fun default(): DataFormatParserRegistry = DataFormatParserRegistry() + } +} + +/** Returns the default built-in parser set. */ +public fun defaultDataFormatParsers(): List = + listOf( + DelimitedTextDataFormatParser(DataFormat.CSV, delimiter = ','), + DelimitedTextDataFormatParser(DataFormat.TSV, delimiter = '\t'), + JsonLinesDataFormatParser() + ) + +/** Parser for simple delimited text with a required header row. 
*/ +public class DelimitedTextDataFormatParser( + override val format: DataFormat, + private val delimiter: Char +) : DataFormatParser { + override fun parse(text: String): RawDataset { + val lines = text.lineSequence() + .map { it.trimEnd('\r') } + .filter { it.isNotBlank() } + .toList() + require(lines.isNotEmpty()) { "$format input must contain a header row" } + + val columns = splitDelimitedLine(lines.first(), delimiter) + require(columns.all { it.isNotBlank() }) { "$format header columns must not be blank" } + + val rows = lines.drop(1).mapIndexed { index, line -> + val values = splitDelimitedLine(line, delimiter) + require(values.size == columns.size) { + "$format row ${index + 2} has ${values.size} values but expected ${columns.size}" + } + RawDataRow(columns.zip(values).toMap()) + } + + return RawDataset( + rows = rows, + schema = DataSchema(columns), + metadata = mapOf("format" to format.name, "rowCount" to rows.size.toString()) + ) + } +} + +/** Parser for newline-delimited JSON objects. 
*/ +public class JsonLinesDataFormatParser( + private val json: Json = Json +) : DataFormatParser { + override val format: DataFormat = DataFormat.JSON_LINES + + override fun parse(text: String): RawDataset { + val objects = text.lineSequence() + .map { it.trim() } + .filter { it.isNotEmpty() } + .mapIndexed { index, line -> + val element = json.parseToJsonElement(line) + require(element is JsonObject) { "JSON_LINES row ${index + 1} must be a JSON object" } + element + } + .toList() + + val columns = objects + .flatMap { it.keys } + .distinct() + + val rows = objects.map { obj -> + RawDataRow(columns.associateWith { column -> obj[column]?.toRawString().orEmpty() }) + } + + return RawDataset( + rows = rows, + schema = DataSchema(columns), + metadata = mapOf("format" to format.name, "rowCount" to rows.size.toString()) + ) + } +} + +private fun splitDelimitedLine(line: String, delimiter: Char): List { + val values = mutableListOf() + val current = StringBuilder() + var quoted = false + var index = 0 + + while (index < line.length) { + val char = line[index] + when { + char == '"' && quoted && index + 1 < line.length && line[index + 1] == '"' -> { + current.append('"') + index++ + } + char == '"' -> quoted = !quoted + char == delimiter && !quoted -> { + values.add(current.toString()) + current.clear() + } + else -> current.append(char) + } + index++ + } + + require(!quoted) { "Unterminated quoted field" } + values.add(current.toString()) + return values +} + +private fun JsonElement.toRawString(): String = + when (this) { + JsonNull -> "" + is JsonPrimitive -> content + is JsonArray -> toString() + is JsonObject -> toString() + } diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt new file mode 100644 index 00000000..8f7c33c1 --- /dev/null +++ 
b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt @@ -0,0 +1,90 @@ +package sk.ainet.data.source + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertSame + +class DataFormatParserTest { + @Test + fun parsesCsvHeaderAndQuotedValues() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.CSV, + "name,label,comment\n" + + "Ada,math,\"first, second\"\n" + + "Grace,compiler,\"said \"\"hello\"\"\"" + ) + + assertEquals(listOf("name", "label", "comment"), dataset.schema.columns) + assertEquals(2, dataset.rows.size) + assertEquals("Ada", dataset.rows[0].values["name"]) + assertEquals("first, second", dataset.rows[0].values["comment"]) + assertEquals("said \"hello\"", dataset.rows[1].values["comment"]) + assertEquals("CSV", dataset.metadata["format"]) + assertEquals("2", dataset.metadata["rowCount"]) + } + + @Test + fun parsesTsvRows() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.TSV, + "id\tvalue\n" + + "1\t42\n" + + "2\t84" + ) + + assertEquals(listOf("id", "value"), dataset.schema.columns) + assertEquals( + mapOf("id" to "2", "value" to "84"), + dataset.rows[1].values + ) + } + + @Test + fun parsesJsonLinesWithUnionSchema() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.JSON_LINES, + "{\"id\":1,\"label\":\"cat\",\"pixels\":[0,1],\"meta\":{\"split\":\"train\"}}\n" + + "{\"id\":2,\"label\":\"dog\",\"score\":0.5}" + ) + + assertEquals(listOf("id", "label", "pixels", "meta", "score"), dataset.schema.columns) + assertEquals("1", dataset.rows[0].values["id"]) + assertEquals("[0,1]", dataset.rows[0].values["pixels"]) + assertEquals("{\"split\":\"train\"}", dataset.rows[0].values["meta"]) + assertEquals("", dataset.rows[0].values["score"]) + assertEquals("0.5", dataset.rows[1].values["score"]) + } + + @Test + fun replacesRegisteredParser() { + val registry = 
DataFormatParserRegistry(parsers = emptyList()) + val parser = object : DataFormatParser { + override val format: DataFormat = DataFormat.CSV + + override fun parse(text: String): RawDataset = + RawDataset( + rows = listOf(RawDataRow(mapOf("value" to text))), + schema = DataSchema(listOf("value")) + ) + } + + registry.register(parser) + + assertSame(parser, registry.parserFor(DataFormat.CSV)) + assertEquals("payload", registry.parse(DataFormat.CSV, "payload").rows.single().values["value"]) + assertFailsWith { + registry.parserFor(DataFormat.TSV) + } + } + + @Test + fun rejectsDelimitedRowsWithWrongWidth() { + assertFailsWith { + DataFormatParserRegistry.default().parse( + DataFormat.CSV, + "a,b\n1,2,3" + ) + } + } +} From a17266ab3ffe2b8d7e9399cfeaae7585c4c5436e Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:41:53 +0200 Subject: [PATCH 15/18] data: parse json raw datasets --- .../sk/ainet/data/source/DataFormatParser.kt | 48 ++++++++++++++----- .../ainet/data/source/DataFormatParserTest.kt | 35 ++++++++++++++ 2 files changed, 71 insertions(+), 12 deletions(-) diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt index 5a36877e..4db2a7e4 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt @@ -11,6 +11,7 @@ import kotlinx.serialization.json.JsonPrimitive public enum class DataFormat(public val extensions: Set) { CSV(setOf("csv")), TSV(setOf("tsv")), + JSON(setOf("json")), JSON_LINES(setOf("jsonl", "ndjson")) } @@ -65,6 +66,7 @@ public fun defaultDataFormatParsers(): List = listOf( DelimitedTextDataFormatParser(DataFormat.CSV, delimiter = ','), DelimitedTextDataFormatParser(DataFormat.TSV, delimiter = '\t'), + JsonDataFormatParser(), 
JsonLinesDataFormatParser() ) @@ -99,6 +101,27 @@ public class DelimitedTextDataFormatParser( } } +/** Parser for JSON object datasets encoded as one object or an array of objects. */ +public class JsonDataFormatParser( + private val json: Json = Json +) : DataFormatParser { + override val format: DataFormat = DataFormat.JSON + + override fun parse(text: String): RawDataset { + val root = json.parseToJsonElement(text) + val objects = when (root) { + is JsonObject -> listOf(root) + is JsonArray -> root.mapIndexed { index, element -> + require(element is JsonObject) { "JSON array element ${index + 1} must be an object" } + element + } + else -> throw IllegalArgumentException("JSON input must be an object or array of objects") + } + + return objects.toRawDataset(format) + } +} + /** Parser for newline-delimited JSON objects. */ public class JsonLinesDataFormatParser( private val json: Json = Json @@ -116,20 +139,21 @@ public class JsonLinesDataFormatParser( } .toList() - val columns = objects - .flatMap { it.keys } - .distinct() - - val rows = objects.map { obj -> - RawDataRow(columns.associateWith { column -> obj[column]?.toRawString().orEmpty() }) - } + return objects.toRawDataset(format) + } +} - return RawDataset( - rows = rows, - schema = DataSchema(columns), - metadata = mapOf("format" to format.name, "rowCount" to rows.size.toString()) - ) +private fun List.toRawDataset(format: DataFormat): RawDataset { + val columns = flatMap { it.keys }.distinct() + val rows = map { obj -> + RawDataRow(columns.associateWith { column -> obj[column]?.toRawString().orEmpty() }) } + + return RawDataset( + rows = rows, + schema = DataSchema(columns), + metadata = mapOf("format" to format.name, "rowCount" to rows.size.toString()) + ) } private fun splitDelimitedLine(line: String, delimiter: Char): List { diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt 
b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt index 8f7c33c1..788d3c38 100644 --- a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataFormatParserTest.kt @@ -56,6 +56,41 @@ class DataFormatParserTest { assertEquals("0.5", dataset.rows[1].values["score"]) } + @Test + fun parsesJsonArrayWithUnionSchema() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.JSON, + "[" + + "{\"id\":1,\"label\":\"cat\",\"pixels\":[0,1]}," + + "{\"id\":2,\"label\":\"dog\",\"score\":0.5}" + + "]" + ) + + assertEquals(listOf("id", "label", "pixels", "score"), dataset.schema.columns) + assertEquals("JSON", dataset.metadata["format"]) + assertEquals("2", dataset.metadata["rowCount"]) + assertEquals("[0,1]", dataset.rows[0].values["pixels"]) + assertEquals("", dataset.rows[0].values["score"]) + } + + @Test + fun parsesJsonSingleObject() { + val dataset = DataFormatParserRegistry.default().parse( + DataFormat.JSON, + "{\"id\":1,\"label\":\"cat\"}" + ) + + assertEquals(listOf("id", "label"), dataset.schema.columns) + assertEquals(mapOf("id" to "1", "label" to "cat"), dataset.rows.single().values) + } + + @Test + fun rejectsJsonArraysWithNonObjectElements() { + assertFailsWith { + DataFormatParserRegistry.default().parse(DataFormat.JSON, "[1]") + } + } + @Test fun replacesRegisteredParser() { val registry = DataFormatParserRegistry(parsers = emptyList()) From 5655c8642398abc3d7d83732862d4a2de56232a0 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:44:01 +0200 Subject: [PATCH 16/18] data: add data source dataset builder --- .../sk/ainet/data/source/DataFormatParser.kt | 20 ++- .../data/source/DataSourceDatasetBuilder.kt | 132 ++++++++++++++++++ .../source/DataSourceDatasetBuilderTest.kt | 125 +++++++++++++++++ 3 files changed, 276 insertions(+), 1 deletion(-) 
create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceDatasetBuilder.kt create mode 100644 skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceDatasetBuilderTest.kt diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt index 4db2a7e4..d6bf7ec0 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataFormatParser.kt @@ -12,7 +12,25 @@ public enum class DataFormat(public val extensions: Set) { CSV(setOf("csv")), TSV(setOf("tsv")), JSON(setOf("json")), - JSON_LINES(setOf("jsonl", "ndjson")) + JSON_LINES(setOf("jsonl", "ndjson")); + + public companion object { + public fun fromExtension(extension: String): DataFormat? { + val normalized = extension.trim().lowercase().removePrefix(".") + if (normalized.isBlank()) return null + return entries.firstOrNull { format -> normalized in format.extensions } + } + + public fun inferFromFilename(filename: String): DataFormat? { + val normalized = filename + .substringBefore('?') + .substringBefore('#') + .trimEnd('/') + .substringAfterLast('/') + val extension = normalized.substringAfterLast('.', missingDelimiterValue = "") + return fromExtension(extension) + } + } } /** A simple schema inferred from raw stringly parsed data. 
*/ diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceDatasetBuilder.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceDatasetBuilder.kt new file mode 100644 index 00000000..0e7b1bd7 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceDatasetBuilder.kt @@ -0,0 +1,132 @@ +package sk.ainet.data.source + +/** + * Kotlin DSL builder for resolving a data source artifact and parsing it into + * a raw, schema-bearing dataset. Configure it with [from] plus the optional setters, then call [build]. + */ +public class DataSourceDatasetBuilder( + resolver: DataSourceResolver? = null, + parserRegistry: DataFormatParserRegistry = DataFormatParserRegistry.default() +) { + private var resolver: DataSourceResolver? = resolver + private var parserRegistry: DataFormatParserRegistry = parserRegistry + private var sourceUri: String? = null + private var requestedFormat: DataFormat? = null + private var cachePolicy: CachePolicy = CachePolicy.Use + private var expectedSha256: String? = null + private val requestHeaders: MutableMap = linkedMapOf() + private var huggingFaceToken: DataSourceAuthToken? = null + + /** Selects the source artifact URI or local path; the value is trimmed and must not be blank. */ + public fun from(uri: String): DataSourceDatasetBuilder = apply { + val normalized = uri.trim() + require(normalized.isNotEmpty()) { "Data source URI must not be blank" } + sourceUri = normalized + } + + /** Overrides format inference from the source filename. */ + public fun format(format: DataFormat): DataSourceDatasetBuilder = apply { + requestedFormat = format + } + + /** Sets resolver cache behavior for this load; the builder starts with [CachePolicy.Use]. */ + public fun cachePolicy(cachePolicy: CachePolicy): DataSourceDatasetBuilder = apply { + this.cachePolicy = cachePolicy + } + + /** Sets an optional expected SHA-256 checksum for resolver verification; the value is trimmed, and null or blank input clears it. 
*/ + public fun expectedSha256(value: String?): DataSourceDatasetBuilder = apply { + expectedSha256 = value?.trim()?.takeIf { it.isNotEmpty() } + } + + /** Adds one request header; the name is trimmed and must not be blank, and a later call with the same trimmed name overwrites the earlier value. */ + public fun header(name: String, value: String): DataSourceDatasetBuilder = apply { + val normalizedName = name.trim() + require(normalizedName.isNotEmpty()) { "Header name must not be blank" } + requestHeaders[normalizedName] = value + } + + /** Adds request headers by passing each entry to [header], replacing entries with the same name. */ + public fun headers(headers: Map): DataSourceDatasetBuilder = apply { + headers.forEach { (name, value) -> header(name, value) } + } + + /** Sets a provider-specific Hugging Face token for this request; passing null clears a previously set token. */ + public fun huggingFaceToken(token: DataSourceAuthToken?): DataSourceDatasetBuilder = apply { + huggingFaceToken = token + } + + /** Wraps [value] with [DataSourceAuthToken.from] and sets it as the Hugging Face token for this request. */ + public fun huggingFaceToken(value: String): DataSourceDatasetBuilder = + huggingFaceToken(DataSourceAuthToken.from(value)) + + /** Supplies the source resolver used to materialize the artifact. */ + public fun resolver(resolver: DataSourceResolver): DataSourceDatasetBuilder = apply { + this.resolver = resolver + } + + /** Replaces the parser registry used by this builder. */ + public fun parserRegistry(parserRegistry: DataFormatParserRegistry): DataSourceDatasetBuilder = apply { + this.parserRegistry = parserRegistry + } + + /** Registers or replaces one parser in this builder's parser registry by calling register on the registry currently held. NOTE(review): presumably this mutates that registry in place, including one supplied by the caller - confirm. */ + public fun parser(parser: DataFormatParser): DataSourceDatasetBuilder = apply { + parserRegistry.register(parser) + } + + /** Resolves and parses the configured source artifact. The format set via [format] wins; otherwise it is inferred from the artifact filename. Throws [DataSourceException] when no source URI or resolver is configured or when the format cannot be inferred. Artifact bytes are decoded to text with decodeToString() before parsing. */ + public suspend fun build(): RawDataset { + val request = buildRequest() + val artifact = requireResolver().resolve(request) + val format = requestedFormat + ?: DataFormat.inferFromFilename(artifact.filename) + ?: throw DataSourceException( + "Cannot infer data format from '${artifact.filename}'. 
Specify format(...) explicitly." + ) + val text = artifact.readBytes().decodeToString() + return parserRegistry.parse(format, text).withSourceMetadata(artifact) + } + + private fun requireResolver(): DataSourceResolver = + resolver ?: throw DataSourceException("A DataSourceResolver is required to load a data source dataset") + + private fun buildRequest(): DataSourceRequest { + val uri = sourceUri ?: throw DataSourceException("Data source URI is required; call from(...) first") + return DataSourceRequest( + uri = uri, + cachePolicy = cachePolicy, + expectedSha256 = expectedSha256, + headers = requestHeaders.toMap(), + huggingFaceToken = huggingFaceToken + ) + } + + private fun RawDataset.withSourceMetadata(artifact: DataSourceArtifact): RawDataset { + val sourceMetadata = mutableMapOf( + "sourceUri" to artifact.request.uri, + "sourceProvider" to artifact.parsedUri.provider.name, + "sourceFilename" to artifact.filename, + "sourceCacheHit" to artifact.cacheHit.toString() + ) + artifact.localPath?.let { sourceMetadata["sourceLocalPath"] = it } + artifact.sizeBytes?.let { sourceMetadata["sourceSizeBytes"] = it.toString() } + return copy(metadata = metadata + sourceMetadata) + } +} + +/** Builds a raw dataset by resolving and parsing a configured data source. The builder starts without a resolver, so the block must call resolver(...) or build fails with [DataSourceException]. */ +public suspend fun rawDataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + DataSourceDatasetBuilder().apply(block).build() + +/** Kotlinish alias for [rawDataset]. */ +public suspend fun dataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + rawDataset(block) + +/** Builds a raw dataset with this resolver preconfigured; the block may still replace it via resolver(...). */ +public suspend fun DataSourceResolver.rawDataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + DataSourceDatasetBuilder(this).apply(block).build() + +/** Kotlinish alias for [DataSourceResolver.rawDataset]. 
*/ +public suspend fun DataSourceResolver.dataset(block: DataSourceDatasetBuilder.() -> Unit): RawDataset = + rawDataset(block) diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceDatasetBuilderTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceDatasetBuilderTest.kt new file mode 100644 index 00000000..ee6ec352 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceDatasetBuilderTest.kt @@ -0,0 +1,125 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class DataSourceDatasetBuilderTest { + @Test + fun resolvesSourceAndInfersCsvFormat() = runTest { + val resolver = FixtureResolver("id,label\n1,cat\n2,dog") + val token = DataSourceAuthToken.from("hf_token") + + val dataset = resolver.rawDataset { + from("hf://datasets/org/repo@main/train.csv") + cachePolicy(CachePolicy.Refresh) + expectedSha256("sha256") + header("Accept", "text/csv") + huggingFaceToken(token) + } + + assertEquals(listOf("id", "label"), dataset.schema.columns) + assertEquals(mapOf("id" to "2", "label" to "dog"), dataset.rows[1].values) + assertEquals("CSV", dataset.metadata["format"]) + assertEquals("hf://datasets/org/repo@main/train.csv", dataset.metadata["sourceUri"]) + assertEquals("HuggingFace", dataset.metadata["sourceProvider"]) + assertEquals("train.csv", dataset.metadata["sourceFilename"]) + assertEquals("false", dataset.metadata["sourceCacheHit"]) + assertEquals("20", dataset.metadata["sourceSizeBytes"]) + + val request = resolver.lastRequest + assertEquals(CachePolicy.Refresh, request?.cachePolicy) + assertEquals("sha256", request?.expectedSha256) + assertEquals(mapOf("Accept" to "text/csv"), request?.headers) + assertEquals(token, request?.huggingFaceToken) + } + + @Test + fun explicitFormatOverridesFilenameInference() = runTest { 
+ val resolver = FixtureResolver("id\tvalue\n1\t42") + + val dataset = resolver.rawDataset { + from("fixtures/train.txt") + format(DataFormat.TSV) + } + + assertEquals(listOf("id", "value"), dataset.schema.columns) + assertEquals(mapOf("id" to "1", "value" to "42"), dataset.rows.single().values) + } + + @Test + fun parserRegistrationReplacesDefaultParser() = runTest { + val resolver = FixtureResolver("payload") + val customParser = object : DataFormatParser { + override val format: DataFormat = DataFormat.CSV + + override fun parse(text: String): RawDataset = + RawDataset( + rows = listOf(RawDataRow(mapOf("raw" to text.uppercase()))), + schema = DataSchema(listOf("raw")) + ) + } + + val dataset = resolver.rawDataset { + from("fixture.csv") + parser(customParser) + } + + assertEquals(mapOf("raw" to "PAYLOAD"), dataset.rows.single().values) + } + + @Test + fun failsWhenFormatCannotBeInferred() = runTest { + val resolver = FixtureResolver("payload") + + assertFailsWith { + resolver.rawDataset { + from("fixture.bin") + } + } + } + + @Test + fun failsWhenResolverIsMissing() = runTest { + assertFailsWith { + rawDataset { + from("fixture.csv") + } + } + } + + @Test + fun failsWhenSourceIsMissing() = runTest { + val resolver = FixtureResolver("payload") + + assertFailsWith { + resolver.rawDataset { + format(DataFormat.CSV) + } + } + } +} + +private class FixtureResolver( + text: String +) : DataSourceResolver { + private val bytes = text.encodeToByteArray() + + var lastRequest: DataSourceRequest? 
= null + private set + + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact { + lastRequest = request + val parsed = DataSourceUriParser.parse(request.uri) + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = parsed.localPath, + sizeBytes = bytes.size.toLong(), + cacheHit = false, + sourceOpener = { DataSourceStoredArtifact.inMemory(bytes, parsed.localPath).openSource() } + ) + } +} From b0eee980977266471612f5dd267190a39cb312d3 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:45:52 +0200 Subject: [PATCH 17/18] data: add suspend data pipeline DSL --- .../sk/ainet/data/source/DataPipeline.kt | 111 ++++++++++++++++++ .../sk/ainet/data/source/DataPipelineTest.kt | 68 +++++++++++ 2 files changed, 179 insertions(+) create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataPipeline.kt create mode 100644 skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataPipelineTest.kt diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataPipeline.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataPipeline.kt new file mode 100644 index 00000000..82b26df1 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataPipeline.kt @@ -0,0 +1,111 @@ +package sk.ainet.data.source + +/** A named suspendable data processing stage. [validate] is a non-suspending check that accepts every input by default; [process] does the work. */ +public interface PipelineStage { + public val name: String + + public fun validate(input: I): Boolean = true + + public suspend fun process(input: I): O +} + +/** A schema-aware stage for data preprocessing pipelines. [process] delegates to [transform], and [getOutputSchema] returns the input schema unless overridden. 
*/ +public interface DataTransformer : PipelineStage { + public suspend fun transform(input: I): O + + public fun getOutputSchema(inputSchema: DataSchema): DataSchema = inputSchema + + override suspend fun process(input: I): O = transform(input) +} + +/** Thrown when a data pipeline cannot execute a stage; [DataPipeline.execute] raises it when validate of a stage returns false. */ +public class DataPipelineException( + message: String, + cause: Throwable? = null +) : DataSourceException(message, cause) + +/** A type-safe sequential data pipeline. [stage] and [then] return a new pipeline instead of modifying this one. */ +public class DataPipeline internal constructor( + private val stages: List> +) { + public val stageNames: List = stages.map { stage -> stage.name } + + /** Returns a new pipeline with [stage] appended to the end of the stages of this pipeline. */ + public fun stage(stage: PipelineStage): DataPipeline = + DataPipeline(stages + stage) + + /** Infix form of the stage function: returns a new pipeline with [stage] appended to the end. */ + public infix fun then(stage: PipelineStage): DataPipeline = + this.stage(stage) + + /** Executes each stage in order, passing the output of one stage to the next. Throws [DataPipelineException] when validate of a stage returns false; exceptions thrown by process are not caught here. A pipeline without stages returns the input. */ + @Suppress("UNCHECKED_CAST") + public suspend fun execute(input: I): O { + var current: Any? = input + for (stage in stages) { + val typedStage = stage as PipelineStage + if (!typedStage.validate(current)) { + throw DataPipelineException("Stage '${stage.name}' rejected its input") + } + current = typedStage.process(current) + } + return current as O + } + + /** Returns a human-readable stage chain: the stage names joined by an arrow separator, or an empty string when there are no stages. */ + public fun describe(): String = stageNames.joinToString(" -> ") +} + +/** Starts an identity data pipeline that has no stages yet. */ +public fun dataPipeline(): DataPipeline = DataPipeline(emptyList()) + +/** Creates a named suspendable pipeline stage from lambdas; [name] must not be blank, and [validate] accepts every input unless overridden. */ +public fun pipelineStage( + name: String, + validate: (I) -> Boolean = { true }, + process: suspend (I) -> O +): PipelineStage { + require(name.isNotBlank()) { "Pipeline stage name must not be blank" } + return FunctionPipelineStage(name, validate, process) +} + +/** Kotlinish alias for [pipelineStage]. 
*/ +public fun stage( + name: String, + validate: (I) -> Boolean = { true }, + process: suspend (I) -> O +): PipelineStage = pipelineStage(name, validate, process) + +/** Creates a named schema-aware data transformer from lambdas; [name] must not be blank, [outputSchema] returns the input schema by default, and [validate] accepts every input by default. */ +public fun dataTransformer( + name: String, + outputSchema: (DataSchema) -> DataSchema = { it }, + validate: (I) -> Boolean = { true }, + transform: suspend (I) -> O +): DataTransformer { + require(name.isNotBlank()) { "Data transformer name must not be blank" } + return FunctionDataTransformer(name, validate, outputSchema, transform) +} + +private class FunctionPipelineStage( + override val name: String, + private val validateInput: (I) -> Boolean, + private val processInput: suspend (I) -> O +) : PipelineStage { + override fun validate(input: I): Boolean = validateInput(input) + + override suspend fun process(input: I): O = processInput(input) +} + +private class FunctionDataTransformer( + override val name: String, + private val validateInput: (I) -> Boolean, + private val outputSchema: (DataSchema) -> DataSchema, + private val transformInput: suspend (I) -> O +) : DataTransformer { + override fun validate(input: I): Boolean = validateInput(input) + + override suspend fun transform(input: I): O = transformInput(input) + + override fun getOutputSchema(inputSchema: DataSchema): DataSchema = outputSchema(inputSchema) +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataPipelineTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataPipelineTest.kt new file mode 100644 index 00000000..57693eba --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataPipelineTest.kt @@ -0,0 +1,68 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class DataPipelineTest { + @Test + fun executesTypedStagesInOrder() = runTest { + 
val pipeline = dataPipeline() + .stage(stage("add-one") { input -> input + 1 }) + .stage(stage("stringify") { input -> "value=$input" }) + + assertEquals("add-one -> stringify", pipeline.describe()) + assertEquals(listOf("add-one", "stringify"), pipeline.stageNames) + assertEquals("value=42", pipeline.execute(41)) + } + + @Test + fun identityPipelineReturnsInput() = runTest { + val pipeline = dataPipeline() + + assertEquals("", pipeline.describe()) + assertEquals(7, pipeline.execute(7)) + } + + @Test + fun rejectsInvalidStageInput() = runTest { + val pipeline = dataPipeline() + .stage( + stage( + name = "positive", + validate = { input -> input > 0 } + ) { input -> input } + ) + + assertFailsWith { + pipeline.execute(0) + } + } + + @Test + fun transformerUpdatesSchemaAndRows() = runTest { + val dropLabel = dataTransformer( + name = "drop-label", + outputSchema = { schema -> DataSchema(schema.columns.filter { column -> column != "label" }) } + ) { dataset -> + val columns = dataset.schema.columns.filter { column -> column != "label" } + dataset.copy( + schema = DataSchema(columns), + rows = dataset.rows.map { row -> + RawDataRow(row.values.filterKeys { key -> key in columns }) + } + ) + } + val input = RawDataset( + rows = listOf(RawDataRow(mapOf("id" to "1", "label" to "cat"))), + schema = DataSchema(listOf("id", "label")) + ) + + val output = (dataPipeline() then dropLabel).execute(input) + + assertEquals(DataSchema(listOf("id")), dropLabel.getOutputSchema(input.schema)) + assertEquals(DataSchema(listOf("id")), output.schema) + assertEquals(mapOf("id" to "1"), output.rows.single().values) + } +} From ce4640f9b53e7427c209cfcec855ac6d7cf97fb6 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 20:46:38 +0200 Subject: [PATCH 18/18] docs: document data loader APIs --- README.md | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 56eda582..587173de 100644 --- a/README.md +++ 
b/README.md @@ -213,8 +213,35 @@ Runnable examples: - Built-in loaders: MNIST, Fashion-MNIST, CIFAR-10 - URI-backed data sources: `file://`, `https://`, `hf+https://`, and `hf://...` +- Dataset operations: deterministic shuffle/split, stratified split, filter/map/transform views, batch flows, and epoch flows +- Raw dataset parsers: CSV, TSV, JSON arrays/objects, JSON Lines (`.jsonl`, `.ndjson`) +- Type-safe transform DSLs: image/tensor transforms plus suspendable raw data pipelines - Formats: GGUF, ONNX, SafeTensors, JSON, Image (JPEG, PNG) -- Type-safe transform DSL: resize, crop, normalize, toTensor + +```kotlin +val raw = JvmDataSourceResolver().rawDataset { + from("hf://datasets/org/repo@main/train.jsonl") + format(DataFormat.JSON_LINES) + cachePolicy(CachePolicy.Use) +} + +val withoutLabel = dataPipeline() + .stage( + dataTransformer( + name = "drop-label", + outputSchema = { schema -> DataSchema(schema.columns - "label") } + ) { dataset -> + val columns = dataset.schema.columns - "label" + dataset.copy( + schema = DataSchema(columns), + rows = dataset.rows.map { row -> + RawDataRow(row.values.filterKeys { key -> key in columns }) + } + ) + } + ) + .execute(raw) +``` ### Edge AI: Arduino / C99 Export