From 717cf85ef53af35bfa4435cdc189ba577dab65aa Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:36:55 +0200 Subject: [PATCH 1/7] data: add URI source contracts --- settings.gradle.kts | 1 + .../skainet-data-source/build.gradle.kts | 38 ++++ .../skainet-data-source/gradle.properties | 2 + .../sk/ainet/data/source/DataSourceModels.kt | 102 +++++++++ .../ainet/data/source/DataSourceUriParser.kt | 200 ++++++++++++++++++ .../data/source/DataSourceUriParserTest.kt | 80 +++++++ 6 files changed, 423 insertions(+) create mode 100644 skainet-data/skainet-data-source/build.gradle.kts create mode 100644 skainet-data/skainet-data-source/gradle.properties create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt create mode 100644 skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt diff --git a/settings.gradle.kts b/settings.gradle.kts index bbfbc825..5e25ef39 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -48,6 +48,7 @@ include("skainet-backends:benchmarks:jvm-cpu-publish") // ====== DATA include("skainet-data:skainet-data-api") +include("skainet-data:skainet-data-source") include("skainet-data:skainet-data-transform") include("skainet-data:skainet-data-simple") include("skainet-data:skainet-data-media") diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts new file mode 100644 index 00000000..f8b4dbc0 --- /dev/null +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -0,0 +1,38 @@ +import org.jetbrains.kotlin.gradle.dsl.JvmTarget + +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.vanniktech.mavenPublish) + id("sk.ainet.dokka") +} + +kotlin { + explicitApi() + + jvm { + compilerOptions { + jvmTarget.set(JvmTarget.JVM_11) + } + } 
+ + sourceSets { + commonMain.dependencies { + implementation(libs.kotlinx.coroutines) + } + + commonTest.dependencies { + implementation(libs.kotlin.test) + } + + jvmMain.dependencies { + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.core) + implementation(libs.ktor.client.plugins) + implementation(libs.kotlinx.coroutines.core.jvm) + } + + jvmTest.dependencies { + implementation(libs.kotlinx.coroutines.test) + } + } +} diff --git a/skainet-data/skainet-data-source/gradle.properties b/skainet-data/skainet-data-source/gradle.properties new file mode 100644 index 00000000..3516f9dd --- /dev/null +++ b/skainet-data/skainet-data-source/gradle.properties @@ -0,0 +1,2 @@ +POM_ARTIFACT_ID=skainet-data-source +POM_NAME=skainet data source diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt new file mode 100644 index 00000000..71eacf32 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt @@ -0,0 +1,102 @@ +package sk.ainet.data.source + +/** + * Cache behavior requested by a caller resolving a data artifact. + */ +public enum class CachePolicy { + /** Use a cached artifact when present, otherwise fetch and cache it. */ + Use, + + /** Fetch the artifact again and replace any existing cached copy. */ + Refresh, + + /** Require an existing cached or local artifact; do not use the network. */ + Offline, + + /** Fetch or read the artifact without writing a persistent cache entry. */ + Bypass +} + +/** + * High-level provider implied by a source URI. + */ +public enum class DataSourceProvider { + File, + Http, + HuggingFace +} + +/** + * Hugging Face repository namespace encoded by an `hf://` URI. 
+ */ +public enum class HuggingFaceRepoType { + Model, + Dataset, + Space +} + +/** + * Parsed Hugging Face location, when a URI uses SKaiNET's `hf://` shorthand + * or the explicit `hf+https://...` provider prefix. + */ +public data class HuggingFaceLocation( + public val repoType: HuggingFaceRepoType, + public val repoId: String?, + public val revision: String?, + public val path: String? +) + +/** + * A normalized, provider-aware source URI. + */ +public data class ParsedDataSourceUri( + public val rawUri: String, + public val provider: DataSourceProvider, + public val transportUri: String, + public val filename: String, + public val cacheKey: String, + public val localPath: String? = null, + public val huggingFace: HuggingFaceLocation? = null +) + +/** + * Request to resolve a local or remote artifact. + */ +public data class DataSourceRequest( + public val uri: String, + public val cachePolicy: CachePolicy = CachePolicy.Use, + public val expectedSha256: String? = null, + public val headers: Map = emptyMap() +) + +/** + * A resolved artifact. Remote artifacts may expose a [localPath] when they + * have been materialized into a platform cache. + */ +public class DataSourceArtifact( + public val request: DataSourceRequest, + public val parsedUri: ParsedDataSourceUri, + public val filename: String, + public val localPath: String?, + public val sizeBytes: Long?, + public val cacheHit: Boolean, + private val byteReader: suspend () -> ByteArray +) { + public suspend fun readBytes(): ByteArray = byteReader() +} + +/** + * Resolves source URIs into readable data artifacts. + */ +public interface DataSourceResolver { + public suspend fun resolve(request: DataSourceRequest): DataSourceArtifact +} + +public open class DataSourceException( + message: String, + cause: Throwable? 
= null +) : RuntimeException(message, cause) + +public class UnsupportedDataSourceUriException( + uri: String +) : DataSourceException("Unsupported data source URI: $uri") diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt new file mode 100644 index 00000000..06378277 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt @@ -0,0 +1,200 @@ +package sk.ainet.data.source + +/** + * Parses SKaiNET data source URIs. + * + * Supported forms: + * - `file:///absolute/path` + * - `/absolute/or/relative/path` + * - `https://host/path` + * - `hf+https://huggingface.co/org/repo/resolve/main/file` + * - `hf://org/repo@revision/path/to/file` + * - `hf://datasets/org/repo@revision/path/to/file` + */ +public object DataSourceUriParser { + public fun parse(uri: String): ParsedDataSourceUri { + val raw = uri.trim() + require(raw.isNotEmpty()) { "Data source URI must not be blank" } + + return when { + raw.startsWith(HF_HTTPS_PREFIX) -> parseHfHttps(raw) + raw.startsWith(HF_URI_PREFIX) -> parseHfUri(raw) + raw.startsWith(FILE_URI_PREFIX) -> parseFileUri(raw) + raw.startsWith(HTTPS_PREFIX) || raw.startsWith(HTTP_PREFIX) -> parseHttp(raw) + raw.contains("://") -> throw UnsupportedDataSourceUriException(raw) + else -> parsePlainFilePath(raw) + } + } + + private fun parseFileUri(raw: String): ParsedDataSourceUri { + val localPath = normalizeFileUriPath(raw.removePrefix(FILE_URI_PREFIX)) + val filename = extractFilename(localPath) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.File, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.File, localPath, filename), + localPath = localPath + ) + } + + private fun parsePlainFilePath(raw: String): ParsedDataSourceUri { + val filename = extractFilename(raw) + 
return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.File, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.File, raw, filename), + localPath = raw + ) + } + + private fun parseHttp(raw: String): ParsedDataSourceUri { + val filename = extractFilename(raw) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.Http, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.Http, raw, filename) + ) + } + + private fun parseHfHttps(raw: String): ParsedDataSourceUri { + val transportUri = raw.removePrefix("hf+") + val filename = extractFilename(transportUri) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.HuggingFace, + transportUri = transportUri, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.HuggingFace, transportUri, filename), + huggingFace = HuggingFaceLocation( + repoType = HuggingFaceRepoType.Model, + repoId = null, + revision = null, + path = null + ) + ) + } + + private fun parseHfUri(raw: String): ParsedDataSourceUri { + val body = raw.removePrefix(HF_URI_PREFIX).trim('/') + val segments = body.split('/').filter { it.isNotBlank() } + require(segments.size >= 3) { + "hf:// URI must include repo owner, repo name, and file path: $raw" + } + + val (repoType, repoStart) = when (segments.first()) { + "models", "model" -> HuggingFaceRepoType.Model to 1 + "datasets", "dataset" -> HuggingFaceRepoType.Dataset to 1 + "spaces", "space" -> HuggingFaceRepoType.Space to 1 + else -> HuggingFaceRepoType.Model to 0 + } + require(segments.size - repoStart >= 3) { + "hf:// URI must include repo owner, repo name, and file path: $raw" + } + + val owner = segments[repoStart] + val repoAndRevision = segments[repoStart + 1] + val repoName = repoAndRevision.substringBefore('@') + val revision = repoAndRevision.substringAfter('@', "main") + val filePath = segments.drop(repoStart + 2).joinToString("/") + val repoId = 
"$owner/$repoName" + val prefix = when (repoType) { + HuggingFaceRepoType.Model -> "" + HuggingFaceRepoType.Dataset -> "datasets/" + HuggingFaceRepoType.Space -> "spaces/" + } + val transportUri = "https://huggingface.co/$prefix$repoId/resolve/$revision/$filePath" + val filename = extractFilename(filePath) + + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.HuggingFace, + transportUri = transportUri, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.HuggingFace, transportUri, filename), + huggingFace = HuggingFaceLocation( + repoType = repoType, + repoId = repoId, + revision = revision, + path = filePath + ) + ) + } + + private fun normalizeFileUriPath(path: String): String { + val withoutLocalhost = path.removePrefix("localhost/") + val normalized = if (withoutLocalhost.startsWith("/")) withoutLocalhost else "/$withoutLocalhost" + return percentDecode(normalized) + } + + private fun extractFilename(value: String): String { + val withoutFragment = value.substringBefore('#').substringBefore('?').trimEnd('/') + val filename = withoutFragment.substringAfterLast('/', missingDelimiterValue = withoutFragment) + return percentDecode(filename).ifBlank { "artifact" } + } + + private fun percentDecode(value: String): String { + val out = StringBuilder(value.length) + var i = 0 + while (i < value.length) { + val c = value[i] + if (c == '%' && i + 2 < value.length) { + val decoded = hexByte(value[i + 1], value[i + 2]) + if (decoded != null) { + out.append(decoded.toInt().toChar()) + i += 3 + continue + } + } + out.append(c) + i++ + } + return out.toString() + } + + private fun hexByte(high: Char, low: Char): Byte? 
{ + val hi = high.digitToIntOrNull(16) ?: return null + val lo = low.digitToIntOrNull(16) ?: return null + return ((hi shl 4) or lo).toByte() + } + + private fun cacheKey(provider: DataSourceProvider, normalizedUri: String, filename: String): String { + val safeName = filename.map { ch -> + if (ch.isLetterOrDigit() || ch == '.' || ch == '-' || ch == '_') ch else '_' + }.joinToString("") + return "${provider.name.lowercase()}-${fnv1a32Hex(normalizedUri)}-$safeName" + } + + private fun fnv1a32Hex(value: String): String { + var hash = FNV_OFFSET + val bytes = value.encodeToByteArray() + for (byte in bytes) { + hash = hash xor (byte.toInt() and 0xff) + hash *= FNV_PRIME + } + return hash.toHex8() + } + + private fun Int.toHex8(): String { + val chars = CharArray(8) + for (i in chars.indices) { + val shift = (7 - i) * 4 + chars[i] = HEX[(this ushr shift) and 0x0f] + } + return chars.concatToString() + } + + private const val FILE_URI_PREFIX = "file://" + private const val HTTP_PREFIX = "http://" + private const val HTTPS_PREFIX = "https://" + private const val HF_HTTPS_PREFIX = "hf+https://" + private const val HF_URI_PREFIX = "hf://" + private const val FNV_OFFSET = -2128831035 + private const val FNV_PRIME = 16777619 + private val HEX: CharArray = "0123456789abcdef".toCharArray() +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt new file mode 100644 index 00000000..9076c5be --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt @@ -0,0 +1,80 @@ +package sk.ainet.data.source + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertNull + +class DataSourceUriParserTest { + @Test + fun parsesFileUri() { + val parsed = 
DataSourceUriParser.parse("file:///tmp/skainet/train-images.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("/tmp/skainet/train-images.idx", parsed.localPath) + assertEquals("train-images.idx", parsed.filename) + } + + @Test + fun parsesPlainPathAsFile() { + val parsed = DataSourceUriParser.parse("fixtures/mnist/train-labels.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("fixtures/mnist/train-labels.idx", parsed.localPath) + assertEquals("train-labels.idx", parsed.filename) + } + + @Test + fun parsesHttpUri() { + val parsed = DataSourceUriParser.parse("https://example.test/data/sample.csv?download=1") + + assertEquals(DataSourceProvider.Http, parsed.provider) + assertEquals("sample.csv", parsed.filename) + assertNull(parsed.huggingFace) + } + + @Test + fun parsesHuggingFaceHttpsProviderPrefix() { + val parsed = DataSourceUriParser.parse( + "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" + ) + + assertEquals(DataSourceProvider.HuggingFace, parsed.provider) + assertEquals( + "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json", + parsed.transportUri + ) + assertEquals("tokenizer.json", parsed.filename) + } + + @Test + fun parsesHuggingFaceDatasetShorthand() { + val parsed = DataSourceUriParser.parse("hf://datasets/mnist/mnist@main/plain_text/train-00000.parquet") + + assertEquals(DataSourceProvider.HuggingFace, parsed.provider) + assertEquals(HuggingFaceRepoType.Dataset, parsed.huggingFace?.repoType) + assertEquals("mnist/mnist", parsed.huggingFace?.repoId) + assertEquals("main", parsed.huggingFace?.revision) + assertEquals("plain_text/train-00000.parquet", parsed.huggingFace?.path) + assertEquals( + "https://huggingface.co/datasets/mnist/mnist/resolve/main/plain_text/train-00000.parquet", + parsed.transportUri + ) + } + + @Test + fun cacheKeyDependsOnNormalizedUri() { + val first = DataSourceUriParser.parse("https://example.test/a.txt") + 
val second = DataSourceUriParser.parse("https://example.test/b.txt") + + assertNotEquals(first.cacheKey, second.cacheKey) + } + + @Test + fun rejectsUnknownSchemes() { + assertFailsWith { + DataSourceUriParser.parse("s3://bucket/object") + } + } +} From f7e22fe4a3ea23093d2c67c7e10e81bdf3a9e9fa Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:38:41 +0200 Subject: [PATCH 2/7] data: materialize JVM source artifacts --- .../ainet/data/source/DataSourceUriParser.kt | 15 +- .../data/source/DataSourceUriParserTest.kt | 9 + .../data/source/JvmDataSourceResolver.kt | 169 ++++++++++++++++++ .../data/source/JvmDataSourceResolverTest.kt | 162 +++++++++++++++++ 4 files changed, 350 insertions(+), 5 deletions(-) create mode 100644 skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt create mode 100644 skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt index 06378277..9281358a 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt @@ -19,7 +19,7 @@ public object DataSourceUriParser { return when { raw.startsWith(HF_HTTPS_PREFIX) -> parseHfHttps(raw) raw.startsWith(HF_URI_PREFIX) -> parseHfUri(raw) - raw.startsWith(FILE_URI_PREFIX) -> parseFileUri(raw) + raw.startsWith(FILE_URI_SCHEME) -> parseFileUri(raw) raw.startsWith(HTTPS_PREFIX) || raw.startsWith(HTTP_PREFIX) -> parseHttp(raw) raw.contains("://") -> throw UnsupportedDataSourceUriException(raw) else -> parsePlainFilePath(raw) @@ -27,7 +27,7 @@ public object DataSourceUriParser { } private fun parseFileUri(raw: String): ParsedDataSourceUri { 
- val localPath = normalizeFileUriPath(raw.removePrefix(FILE_URI_PREFIX)) + val localPath = normalizeFileUriPath(raw.removePrefix(FILE_URI_SCHEME)) val filename = extractFilename(localPath) return ParsedDataSourceUri( rawUri = raw, @@ -127,8 +127,13 @@ public object DataSourceUriParser { } private fun normalizeFileUriPath(path: String): String { - val withoutLocalhost = path.removePrefix("localhost/") - val normalized = if (withoutLocalhost.startsWith("/")) withoutLocalhost else "/$withoutLocalhost" + val normalized = when { + path.startsWith("//localhost/") -> path.removePrefix("//localhost") + path.startsWith("///") -> path.drop(2) + path.startsWith("//") -> path.drop(1) + path.startsWith("/") -> path + else -> "/$path" + } return percentDecode(normalized) } @@ -189,7 +194,7 @@ public object DataSourceUriParser { return chars.concatToString() } - private const val FILE_URI_PREFIX = "file://" + private const val FILE_URI_SCHEME = "file:" private const val HTTP_PREFIX = "http://" private const val HTTPS_PREFIX = "https://" private const val HF_HTTPS_PREFIX = "hf+https://" diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt index 9076c5be..7bcef66f 100644 --- a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt @@ -16,6 +16,15 @@ class DataSourceUriParserTest { assertEquals("train-images.idx", parsed.filename) } + @Test + fun parsesJvmFileUri() { + val parsed = DataSourceUriParser.parse("file:/tmp/skainet/train-images.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("/tmp/skainet/train-images.idx", parsed.localPath) + assertEquals("train-images.idx", parsed.filename) + } + @Test fun parsesPlainPathAsFile() { val 
parsed = DataSourceUriParser.parse("fixtures/mnist/train-labels.idx") diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt new file mode 100644 index 00000000..7cba502a --- /dev/null +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -0,0 +1,169 @@ +package sk.ainet.data.source + +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.request.get +import io.ktor.client.request.header +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import java.io.File +import java.security.MessageDigest + +/** + * Fetches a remote URI into memory. Kept injectable so tests and applications + * can provide their own HTTP stack or policy layer. + */ +public fun interface RemoteDataSourceFetcher { + public suspend fun fetch(uri: String, headers: Map): ByteArray +} + +/** + * Ktor/CIO-backed remote fetcher for JVM data artifacts. + */ +public class KtorRemoteDataSourceFetcher( + private val client: HttpClient = HttpClient(CIO) { + expectSuccess = true + install(HttpTimeout) { + requestTimeoutMillis = 600_000 + connectTimeoutMillis = 60_000 + socketTimeoutMillis = 600_000 + } + } +) : RemoteDataSourceFetcher, AutoCloseable { + override suspend fun fetch(uri: String, headers: Map): ByteArray { + return client.get(uri) { + headers.forEach { (name, value) -> header(name, value) } + }.body() + } + + override fun close() { + client.close() + } +} + +/** + * JVM resolver for local files and cached remote artifacts. 
+ */ +public class JvmDataSourceResolver( + private val cacheDir: File = defaultCacheDir(), + private val fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() +) : DataSourceResolver { + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { + val parsed = DataSourceUriParser.parse(request.uri) + when (parsed.provider) { + DataSourceProvider.File -> resolveFile(request, parsed) + DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed) + } + } + + private fun resolveFile( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") + val file = File(path) + require(file.exists()) { "Data source file not found: ${file.absolutePath}" } + require(file.isFile) { "Data source path is not a file: ${file.absolutePath}" } + request.expectedSha256?.let { verifySha256(file.readBytes(), it, request.uri) } + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = file.absolutePath, + sizeBytes = file.length(), + cacheHit = true, + byteReader = { file.readBytes() } + ) + } + + private suspend fun resolveRemote( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val target = File(cacheDir, parsed.cacheKey) + val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline + if (canUseCache && target.exists() && target.isFile) { + request.expectedSha256?.let { verifySha256(target.readBytes(), it, request.uri) } + return cachedArtifact(request, parsed, target, cacheHit = true) + } + + if (request.cachePolicy == CachePolicy.Offline) { + throw DataSourceException("No cached artifact available for offline source: ${request.uri}") + } + + val bytes = fetcher.fetch(parsed.transportUri, requestHeaders(request, parsed)) + 
request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } + + if (request.cachePolicy == CachePolicy.Bypass) { + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = null, + sizeBytes = bytes.size.toLong(), + cacheHit = false, + byteReader = { bytes } + ) + } + + cacheDir.mkdirs() + val temp = File(cacheDir, "${parsed.cacheKey}.tmp") + temp.writeBytes(bytes) + if (!temp.renameTo(target)) { + temp.copyTo(target, overwrite = true) + temp.delete() + } + return cachedArtifact(request, parsed, target, cacheHit = false) + } + + private fun cachedArtifact( + request: DataSourceRequest, + parsed: ParsedDataSourceUri, + target: File, + cacheHit: Boolean + ): DataSourceArtifact { + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = target.absolutePath, + sizeBytes = target.length(), + cacheHit = cacheHit, + byteReader = { target.readBytes() } + ) + } + + private fun requestHeaders( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): Map { + if (parsed.provider != DataSourceProvider.HuggingFace) return request.headers + if (request.headers.keys.any { it.equals("Authorization", ignoreCase = true) }) return request.headers + val token = System.getenv("HF_TOKEN") + ?.takeIf { it.isNotBlank() } + ?: System.getenv("HUGGING_FACE_HUB_TOKEN")?.takeIf { it.isNotBlank() } + ?: return request.headers + return request.headers + ("Authorization" to "Bearer $token") + } + + private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { + val actual = MessageDigest.getInstance("SHA-256") + .digest(bytes) + .joinToString("") { byte -> "%02x".format(byte) } + if (!actual.equals(expected, ignoreCase = true)) { + throw DataSourceException( + "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" + ) + } + } + + public companion object { + public fun defaultCacheDir(): File { + val userHome = 
System.getProperty("user.home")?.takeIf { it.isNotBlank() } + val base = userHome ?: System.getProperty("java.io.tmpdir") + return File(base, ".cache/skainet/data") + } + } +} diff --git a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt new file mode 100644 index 00000000..5b44523e --- /dev/null +++ b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt @@ -0,0 +1,162 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import java.nio.file.Files +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class JvmDataSourceResolverTest { + @Test + fun resolvesLocalFileUri() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val file = root.resolve("sample.txt") + file.writeText("hello") + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache")) + + val artifact = resolver.resolve(DataSourceRequest(file.toURI().toString())) + + assertEquals("sample.txt", artifact.filename) + assertEquals(file.absolutePath, artifact.localPath) + assertTrue(artifact.cacheHit) + assertContentEquals("hello".encodeToByteArray(), artifact.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun cachesRemoteArtifacts() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("first".encodeToByteArray()) + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = fetcher) + val request = DataSourceRequest( + "hf+https://huggingface.co/example/model/resolve/main/config.json" + ) + + val first = resolver.resolve(request) + val 
second = resolver.resolve(request) + + assertEquals(1, fetcher.calls) + assertFalse(first.cacheHit) + assertTrue(second.cacheHit) + assertNotNull(second.localPath) + assertContentEquals("first".encodeToByteArray(), second.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun refreshFetchesAgain() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = QueueFetcher( + "old".encodeToByteArray(), + "new".encodeToByteArray() + ) + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = fetcher) + val uri = "https://example.test/data.bin" + + resolver.resolve(DataSourceRequest(uri)).readBytes() + val refreshed = resolver.resolve(DataSourceRequest(uri, cachePolicy = CachePolicy.Refresh)) + + assertEquals(2, fetcher.calls) + assertContentEquals("new".encodeToByteArray(), refreshed.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun offlineFailsWhenCacheIsMissing() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = FakeFetcher(ByteArray(0))) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/missing.bin", + cachePolicy = CachePolicy.Offline + ) + ) + } + } finally { + root.deleteRecursively() + } + } + + @Test + fun bypassDoesNotWriteCache() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("bytes".encodeToByteArray()) + val cacheDir = root.resolve("cache") + val resolver = JvmDataSourceResolver(cacheDir = cacheDir, fetcher = fetcher) + + val artifact = resolver.resolve( + DataSourceRequest("https://example.test/data.bin", cachePolicy = CachePolicy.Bypass) + ) + + assertEquals(1, fetcher.calls) + assertEquals(null, artifact.localPath) + assertFalse(cacheDir.exists()) + } finally { + 
root.deleteRecursively() + } + } + + @Test + fun verifiesSha256() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = FakeFetcher("payload".encodeToByteArray()) + ) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/payload.bin", + expectedSha256 = "0000" + ) + ) + } + } finally { + root.deleteRecursively() + } + } +} + +private class FakeFetcher( + private val bytes: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + override suspend fun fetch(uri: String, headers: Map): ByteArray { + calls++ + return bytes + } +} + +private class QueueFetcher( + private vararg val responses: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + override suspend fun fetch(uri: String, headers: Map): ByteArray { + val index = calls.coerceAtMost(responses.lastIndex) + calls++ + return responses[index] + } +} From 9fefcd5410f986f0ef572ee5c3c7935a1209363b Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:41:43 +0200 Subject: [PATCH 3/7] data: route simple loaders through sources --- .../skainet-data-simple/build.gradle.kts | 1 + .../sk/ainet/data/cifar10/CIFAR10Data.kt | 3 +- .../data/fashionmnist/FashionMNISTData.kt | 6 +- .../fashionmnist/FashionMNISTLoaderCommon.kt | 8 +- .../kotlin/sk/ainet/data/mnist/MNISTData.kt | 8 +- .../sk/ainet/data/mnist/MNISTLoaderCommon.kt | 8 +- .../sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt | 72 +++--------- .../fashionmnist/FashionMNISTLoaderJvm.kt | 103 ++++-------------- .../sk/ainet/data/mnist/MNISTLoaderJvm.kt | 103 ++++-------------- .../io/data/cifar10/CIFAR10LoaderTest.kt | 3 +- .../fashionmnist/FashionMNISTLoaderTest.kt | 4 +- .../sk/ainet/io/data/mnist/MNISTLoaderTest.kt | 28 ++++- 12 files changed, 110 insertions(+), 237 deletions(-) diff --git a/skainet-data/skainet-data-simple/build.gradle.kts 
b/skainet-data/skainet-data-simple/build.gradle.kts index 562bedd6..9bd671f0 100644 --- a/skainet-data/skainet-data-simple/build.gradle.kts +++ b/skainet-data/skainet-data-simple/build.gradle.kts @@ -62,6 +62,7 @@ kotlin { } jvmMain.dependencies { + implementation(project(":skainet-data:skainet-data-source")) implementation(libs.ktor.client.cio) implementation(libs.ktor.client.plugins) implementation(libs.ktor.client.logging) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt index f2a1e106..9344e150 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt @@ -144,7 +144,8 @@ public data class CIFAR10Dataset( */ public data class CIFAR10LoaderConfig( val cacheDir: String = "cifar10-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt index b9227418..bdc0e99e 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt @@ -146,7 +146,11 @@ public data class FashionMNISTDataset( */ public data class FashionMNISTLoaderConfig( val cacheDir: String = "fashion-mnist-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val trainImagesUri: String = FashionMNISTConstants.TRAIN_IMAGES_URL, + val trainLabelsUri: String = FashionMNISTConstants.TRAIN_LABELS_URL, + val testImagesUri: String = 
FashionMNISTConstants.TEST_IMAGES_URL, + val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt index 45888031..9fff81fd 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt @@ -16,11 +16,11 @@ public abstract class FashionMNISTLoaderCommon(public val config: FashionMNISTLo */ override suspend fun loadTrainingData(): FashionMNISTDataset { val imagesBytes = downloadAndCacheFile( - FashionMNISTConstants.TRAIN_IMAGES_URL, + config.trainImagesUri, FashionMNISTConstants.TRAIN_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - FashionMNISTConstants.TRAIN_LABELS_URL, + config.trainLabelsUri, FashionMNISTConstants.TRAIN_LABELS_FILENAME ) @@ -34,11 +34,11 @@ public abstract class FashionMNISTLoaderCommon(public val config: FashionMNISTLo */ override suspend fun loadTestData(): FashionMNISTDataset { val imagesBytes = downloadAndCacheFile( - FashionMNISTConstants.TEST_IMAGES_URL, + config.testImagesUri, FashionMNISTConstants.TEST_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - FashionMNISTConstants.TEST_LABELS_URL, + config.testLabelsUri, FashionMNISTConstants.TEST_LABELS_FILENAME ) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt index d120cdd6..c4bfa76a 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt @@ -124,7 +124,11 @@ public data class 
MNISTDataset( */ public data class MNISTLoaderConfig( val cacheDir: String = "mnist-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val trainImagesUri: String = MNISTConstants.TRAIN_IMAGES_URL, + val trainLabelsUri: String = MNISTConstants.TRAIN_LABELS_URL, + val testImagesUri: String = MNISTConstants.TEST_IMAGES_URL, + val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL ) /** @@ -164,4 +168,4 @@ public interface MNISTLoader { * @return The MNIST test dataset. */ public suspend fun loadTestData(): MNISTDataset -} \ No newline at end of file +} diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt index 4c334d31..66ef8085 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt @@ -14,11 +14,11 @@ public abstract class MNISTLoaderCommon(public val config: MNISTLoaderConfig) : */ override suspend fun loadTrainingData(): MNISTDataset { val imagesBytes = downloadAndCacheFile( - MNISTConstants.TRAIN_IMAGES_URL, + config.trainImagesUri, MNISTConstants.TRAIN_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - MNISTConstants.TRAIN_LABELS_URL, + config.trainLabelsUri, MNISTConstants.TRAIN_LABELS_FILENAME ) @@ -32,11 +32,11 @@ public abstract class MNISTLoaderCommon(public val config: MNISTLoaderConfig) : */ override suspend fun loadTestData(): MNISTDataset { val imagesBytes = downloadAndCacheFile( - MNISTConstants.TEST_IMAGES_URL, + config.testImagesUri, MNISTConstants.TEST_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - MNISTConstants.TEST_LABELS_URL, + config.testLabelsUri, MNISTConstants.TEST_LABELS_FILENAME ) diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt 
b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt index 82212e3f..c6ddcce5 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt @@ -1,18 +1,13 @@ package sk.ainet.data.cifar10 -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream import java.io.File -import java.io.FileInputStream import java.io.FileOutputStream -import java.io.ByteArrayInputStream import java.util.zip.GZIPInputStream /** @@ -23,6 +18,7 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the CIFAR-10 loader. */ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) { + private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) /** * Downloads the CIFAR-10 archive and extracts the specified batch file. 
@@ -45,21 +41,16 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon return@withContext batchFile.readBytes() } - // Check if we need to download and extract the archive + // Check if we need to resolve and extract the archive if (!extractedDir.exists() || !config.useCache) { - val archiveFile = File(cacheDir, CIFAR10Constants.ARCHIVE_FILENAME) - - // Download if not cached - if (!archiveFile.exists() || !config.useCache) { - println("Downloading CIFAR-10 archive: ${CIFAR10Constants.DOWNLOAD_URL}") - downloadFile(CIFAR10Constants.DOWNLOAD_URL, archiveFile.path) - } else { - println("Using cached archive: ${archiveFile.path}") - } - - // Extract the archive + val archive = resolver.resolve( + DataSourceRequest( + uri = config.archiveUri, + cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh + ) + ) println("Extracting CIFAR-10 archive...") - extractTarGz(archiveFile.path, cacheDir.path) + extractTarGz(archive.readBytes(), cacheDir.path) } if (!batchFile.exists()) { @@ -69,48 +60,17 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon return@withContext batchFile.readBytes() } - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. 
- */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files (CIFAR-10 is ~170MB) - install(HttpTimeout) { - requestTimeoutMillis = 600000 // 10 minutes - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 600000 // 10 minutes - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path} (${responseBody.size} bytes)") - } finally { - client.close() - } - } - /** * Extracts a .tar.gz archive using a simple TAR parser. * - * @param archivePath The path to the .tar.gz file. + * @param archiveBytes The bytes of the .tar.gz file. * @param outputDir The directory to extract files to. */ - private fun extractTarGz(archivePath: String, outputDir: String) { + private fun extractTarGz(archiveBytes: ByteArray, outputDir: String) { val outputDirFile = File(outputDir) // First, decompress gzip to get the tar content - val tarBytes = GZIPInputStream(FileInputStream(archivePath)).use { gzipIn -> + val tarBytes = GZIPInputStream(ByteArrayInputStream(archiveBytes)).use { gzipIn -> gzipIn.readBytes() } diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt index 2444a779..332dfeb2 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt @@ -1,18 +1,13 @@ package sk.ainet.data.fashionmnist -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import 
io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import java.io.File -import java.io.FileInputStream -import java.io.FileOutputStream +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream import java.util.zip.GZIPInputStream +import java.io.File /** * JVM implementation of the Fashion-MNIST loader. @@ -20,91 +15,32 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the Fashion-MNIST loader. */ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) { + private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) /** - * Downloads and caches a file. + * Resolves, caches, and decompresses a file when needed. * * @param url The URL to download from. * @param filename The name of the file to save. * @return The bytes of the decompressed file. 
*/ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val cacheDir = File(config.cacheDir) - if (!cacheDir.exists()) { - cacheDir.mkdirs() - } - - val gzipFile = File(cacheDir, filename) - val decompressedFile = File(cacheDir, filename.removeSuffix(".gz")) - - // Check if the decompressed file already exists in cache - if (config.useCache && decompressedFile.exists()) { - println("Using cached file: ${decompressedFile.path}") - return@withContext decompressedFile.readBytes() - } - - // Check if the gzip file already exists in cache - if (!gzipFile.exists() || !config.useCache) { - println("Downloading Fashion-MNIST file: $url") - downloadFile(url, gzipFile.path) - } else { - println("Using cached gzip file: ${gzipFile.path}") - } - - // Decompress the gzip file - println("Decompressing file: ${gzipFile.path}") - decompressGzipFile(gzipFile.path, decompressedFile.path) - - return@withContext decompressedFile.readBytes() + val artifact = resolver.resolve( + DataSourceRequest( + uri = url, + cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh + ) + ) + return@withContext maybeGunzip(artifact.readBytes()) } - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. 
- */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files - install(HttpTimeout) { - requestTimeoutMillis = 60000 // 60 seconds - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 60000 // 60 seconds - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path}") - } finally { - client.close() - } + private fun maybeGunzip(bytes: ByteArray): ByteArray { + if (!bytes.isGzip()) return bytes + return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } } - /** - * Decompresses a gzip file. - * - * @param gzipFilePath The path to the gzip file. - * @param outputFilePath The path to save the decompressed file to. - */ - private fun decompressGzipFile(gzipFilePath: String, outputFilePath: String) { - GZIPInputStream(FileInputStream(gzipFilePath)).use { gzipInputStream -> - FileOutputStream(outputFilePath).use { outputStream -> - val buffer = ByteArray(1024) - var len: Int - while (gzipInputStream.read(buffer).also { len = it } > 0) { - outputStream.write(buffer, 0, len) - } - } - } + private fun ByteArray.isGzip(): Boolean { + return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() } public companion object { @@ -137,4 +73,5 @@ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMN return FashionMNISTLoaderJvm(config) } } + } diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt index b6bcb9aa..330197d2 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt +++ 
b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt @@ -1,18 +1,13 @@ package sk.ainet.data.mnist -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import java.io.File -import java.io.FileInputStream -import java.io.FileOutputStream +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream import java.util.zip.GZIPInputStream +import java.io.File /** * JVM implementation of the MNIST loader. @@ -20,91 +15,32 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the MNIST loader. */ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) { + private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) /** - * Downloads and caches a file. + * Resolves, caches, and decompresses a file when needed. * * @param url The URL to download from. * @param filename The name of the file to save. * @return The bytes of the decompressed file. 
*/ override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val cacheDir = File(config.cacheDir) - if (!cacheDir.exists()) { - cacheDir.mkdirs() - } - - val gzipFile = File(cacheDir, filename) - val decompressedFile = File(cacheDir, filename.removeSuffix(".gz")) - - // Check if the decompressed file already exists in cache - if (config.useCache && decompressedFile.exists()) { - println("Using cached file: ${decompressedFile.path}") - return@withContext decompressedFile.readBytes() - } - - // Check if the gzip file already exists in cache - if (!gzipFile.exists() || !config.useCache) { - println("Downloading file: $url") - downloadFile(url, gzipFile.path) - } else { - println("Using cached gzip file: ${gzipFile.path}") - } - - // Decompress the gzip file - println("Decompressing file: ${gzipFile.path}") - decompressGzipFile(gzipFile.path, decompressedFile.path) - - return@withContext decompressedFile.readBytes() + val artifact = resolver.resolve( + DataSourceRequest( + uri = url, + cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh + ) + ) + return@withContext maybeGunzip(artifact.readBytes()) } - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. 
- */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files - install(HttpTimeout) { - requestTimeoutMillis = 300000 // 5 minutes - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 300000 // 5 minutes - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path}") - } finally { - client.close() - } + private fun maybeGunzip(bytes: ByteArray): ByteArray { + if (!bytes.isGzip()) return bytes + return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } } - /** - * Decompresses a gzip file. - * - * @param gzipFilePath The path to the gzip file. - * @param outputFilePath The path to save the decompressed file to. - */ - private fun decompressGzipFile(gzipFilePath: String, outputFilePath: String) { - GZIPInputStream(FileInputStream(gzipFilePath)).use { gzipInputStream -> - FileOutputStream(outputFilePath).use { outputStream -> - val buffer = ByteArray(1024) - var len: Int - while (gzipInputStream.read(buffer).also { len = it } > 0) { - outputStream.write(buffer, 0, len) - } - } - } + private fun ByteArray.isGzip(): Boolean { + return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() } public companion object { @@ -137,4 +73,5 @@ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi return MNISTLoaderJvm(config) } } + } diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt index 06b23b39..541516b3 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt +++ 
b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt @@ -60,7 +60,8 @@ class CIFAR10LoaderTest { fun testLoaderConfiguration() { val config = CIFAR10LoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + archiveUri = "hf+https://huggingface.co/datasets/cifar10/resolve/main/cifar-10-binary.tar.gz" ) val loader = createCIFAR10Loader(config) diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt index 5f80dc31..30135890 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt @@ -60,7 +60,9 @@ class FashionMNISTLoaderTest { fun testLoaderConfiguration() { val config = FashionMNISTLoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + trainImagesUri = "file:///datasets/fashion-mnist/train-images", + trainLabelsUri = "hf+https://huggingface.co/datasets/zalando-datasets/fashion_mnist/resolve/main/train-labels" ) val loader = createFashionMNISTLoader(config) diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt index 82e04d75..94796645 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.mnist.MNISTImage import sk.ainet.data.mnist.MNISTLoaderConfig import sk.ainet.data.mnist.MNISTLoaderFactory import sk.ainet.data.mnist.MNISTLoaderCommon +import java.nio.file.Files import 
kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -60,12 +61,37 @@ class MNISTLoaderTest { fun testLoaderConfiguration() { val config = MNISTLoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + trainImagesUri = "file:///datasets/mnist/train-images", + trainLabelsUri = "hf+https://huggingface.co/datasets/mnist/mnist/resolve/main/train-labels" ) val loader = MNISTLoaderFactory.create(config) assertNotNull(loader) } + + @Test + fun testJvmLoaderReadsConfiguredFileUris() = runBlocking { + val root = Files.createTempDirectory("skainet-mnist-loader-test").toFile() + try { + val trainImages = root.resolve("train-images.idx") + val trainLabels = root.resolve("train-labels.idx") + trainImages.writeBytes(TRAINING_IMAGES_BYTES) + trainLabels.writeBytes(TRAINING_LABELS_BYTES) + val config = MNISTLoaderConfig( + cacheDir = root.resolve("cache").absolutePath, + useCache = false, + trainImagesUri = trainImages.toURI().toString(), + trainLabelsUri = trainLabels.toURI().toString() + ) + + val dataset = MNISTLoaderFactory.create(config).loadTrainingData() + + assertEquals(EXPECTED_TRAINING_DATA, dataset.images) + } finally { + root.deleteRecursively() + } + } } private fun createFakeLoader(config: MNISTLoaderConfig = MNISTLoaderConfig()): FakeMNISTLoader { From 8f4aaeee3a97d53ab3da43323d4cc10f647b5c8b Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 13:43:08 +0200 Subject: [PATCH 4/7] docs: explain data source URIs --- README.md | 1 + build.gradle.kts | 3 +- docs/modules/ROOT/nav.adoc | 1 + .../data-sources-getting-started.adoc | 117 ++++++++++++++++++ 4 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc diff --git a/README.md b/README.md index 59beabfe..56eda582 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,7 @@ Runnable examples: ### Data and I/O - Built-in loaders: MNIST, Fashion-MNIST, CIFAR-10 +- 
URI-backed data sources: `file://`, `https://`, `hf+https://`, and `hf://...` - Formats: GGUF, ONNX, SafeTensors, JSON, Image (JPEG, PNG) - Type-safe transform DSL: resize, crop, normalize, toTensor diff --git a/build.gradle.kts b/build.gradle.kts index 771b2133..d76ff097 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -151,6 +151,7 @@ dependencies { // skainet-data dokka(project(":skainet-data:skainet-data-api")) + dokka(project(":skainet-data:skainet-data-source")) dokka(project(":skainet-data:skainet-data-transform")) dokka(project(":skainet-data:skainet-data-simple")) dokka(project(":skainet-data:skainet-data-media")) @@ -178,4 +179,4 @@ tasks.register("bundleDokkaIntoSite") { dependsOn("dokkaGenerate") from(layout.buildDirectory.dir("dokka/html")) into(layout.projectDirectory.dir("docs/build/site/api")) -} \ No newline at end of file +} diff --git a/docs/modules/ROOT/nav.adoc b/docs/modules/ROOT/nav.adoc index 1c8ed540..e6a714cd 100644 --- a/docs/modules/ROOT/nav.adoc +++ b/docs/modules/ROOT/nav.adoc @@ -5,6 +5,7 @@ * Tutorials ** xref:tutorials/kotlin-getting-started.adoc[Kotlin getting started] ** xref:tutorials/java-getting-started.adoc[Java getting started] +** xref:tutorials/data-sources-getting-started.adoc[Data sources and Hugging Face] ** xref:tutorials/image-data-getting-started.adoc[Image and data API] ** xref:tutorials/hlo-getting-started.adoc[StableHLO getting started] ** xref:tutorials/minerva-getting-started.adoc[Minerva getting started] diff --git a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc new file mode 100644 index 00000000..b4001501 --- /dev/null +++ b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc @@ -0,0 +1,117 @@ +== Data sources and Hugging Face + +SKaiNET separates artifact resolution from dataset parsing and preprocessing. 
+Use `skainet-data-source` when a dataset, tokenizer, model sidecar, or fixture +can live either on disk or behind a remote URI. + +[cols="1,3",options="header"] +|=== +| URI form | Meaning +| `file:///path/to/file` +| Read a local file. + +| `https://host/path/file` +| Download and cache a generic remote artifact. + +| `hf+https://huggingface.co/org/repo/resolve/main/file` +| Treat a Hugging Face resolve URL as a Hugging Face artifact. + +| `hf://org/repo@main/path/file` +| Expand to a Hugging Face model repository resolve URL. + +| `hf://datasets/org/repo@main/path/file` +| Expand to a Hugging Face dataset repository resolve URL. +|=== + +=== Add the modules + +For JVM consumers, add the source module beside the data loaders you use: + +[source,kotlin] +---- +dependencies { + implementation(platform("sk.ainet:skainet-bom:0.32.4")) + + implementation("sk.ainet.core:skainet-data-source-jvm") + implementation("sk.ainet.core:skainet-data-simple-jvm") +} +---- + +=== Resolve one artifact + +`JvmDataSourceResolver` materializes remote artifacts into a cache and returns +a `DataSourceArtifact` that can be read as bytes. Public Hugging Face files do +not need credentials. Private files can use an `Authorization` header, or the +JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the +environment when the URI provider is Hugging Face. + +[source,kotlin] +---- +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver + +val resolver = JvmDataSourceResolver() +val artifact = resolver.resolve( + DataSourceRequest( + uri = "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" + ) +) + +println(artifact.filename) +println(artifact.localPath) + +val bytes = artifact.readBytes() +---- + +=== Use sources with built-in loaders + +MNIST and Fashion-MNIST expose per-file URI overrides. CIFAR-10 exposes an +archive URI override. 
Defaults still point to the historical public dataset +locations, so existing code keeps working. + +[source,kotlin] +---- +import sk.ainet.data.mnist.MNIST +import sk.ainet.data.mnist.MNISTLoaderConfig + +val train = MNIST.loadTrain( + MNISTLoaderConfig( + trainImagesUri = "file:///datasets/mnist/train-images-idx3-ubyte", + trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz" + ) +) + +val batches = train.batchIterator(batchSize = 64) +---- + +=== Cache behavior + +Use `CachePolicy.Use` for normal operation, `Refresh` to re-download, +`Offline` to require a cached copy, and `Bypass` to avoid writing the cache. +Built-in JVM loaders map `useCache = true` to `Use` and `useCache = false` +to `Refresh`. + +[source,kotlin] +---- +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest + +val refreshed = resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/your-org/your-dataset@main/data/train-00000.parquet", + cachePolicy = CachePolicy.Refresh + ) +) +---- + +=== Keep preprocessing separate + +After bytes are parsed into a dataset, continue using the existing transform +DSL for image/tensor preprocessing: + +[source,kotlin] +---- +import sk.ainet.data.transform.mnistPreprocessing + +val preprocessing = mnistPreprocessing(ctx) +---- From 82156a1492cdbbda0e4edbec302ad9de2c86df70 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 14:22:21 +0200 Subject: [PATCH 5/7] data: share source resolver core --- .../sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt | 26 +-- .../data/common/JvmDatasetSourceReader.kt | 40 +++++ .../fashionmnist/FashionMNISTLoaderJvm.kt | 30 +--- .../sk/ainet/data/mnist/MNISTLoaderJvm.kt | 30 +--- .../skainet-data-source/build.gradle.kts | 4 +- .../data/source/DefaultDataSourceResolver.kt | 138 +++++++++++++++ .../source/DefaultDataSourceResolverTest.kt | 164 ++++++++++++++++++ .../data/source/JvmDataSourceResolver.kt | 144 +++++---------- 8 files 
changed, 404 insertions(+), 172 deletions(-) create mode 100644 skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt create mode 100644 skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt create mode 100644 skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt index c6ddcce5..10fc06a3 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt @@ -2,13 +2,10 @@ package sk.ainet.data.cifar10 import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import sk.ainet.data.source.CachePolicy -import sk.ainet.data.source.DataSourceRequest -import sk.ainet.data.source.JvmDataSourceResolver -import java.io.ByteArrayInputStream +import sk.ainet.data.common.JvmDatasetSourceReader +import sk.ainet.data.common.gunzip import java.io.File import java.io.FileOutputStream -import java.util.zip.GZIPInputStream /** * JVM implementation of the CIFAR-10 loader. @@ -18,7 +15,7 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the CIFAR-10 loader. */ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) { - private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) + private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) /** * Downloads the CIFAR-10 archive and extracts the specified batch file. 
@@ -43,14 +40,8 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon // Check if we need to resolve and extract the archive if (!extractedDir.exists() || !config.useCache) { - val archive = resolver.resolve( - DataSourceRequest( - uri = config.archiveUri, - cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh - ) - ) println("Extracting CIFAR-10 archive...") - extractTarGz(archive.readBytes(), cacheDir.path) + extractTarGz(sources.read(config.archiveUri), cacheDir.path) } if (!batchFile.exists()) { @@ -68,14 +59,7 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon */ private fun extractTarGz(archiveBytes: ByteArray, outputDir: String) { val outputDirFile = File(outputDir) - - // First, decompress gzip to get the tar content - val tarBytes = GZIPInputStream(ByteArrayInputStream(archiveBytes)).use { gzipIn -> - gzipIn.readBytes() - } - - // Parse the TAR archive - extractTar(tarBytes, outputDirFile) + extractTar(archiveBytes.gunzip(), outputDirFile) } /** diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt new file mode 100644 index 00000000..29ce6bbf --- /dev/null +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt @@ -0,0 +1,40 @@ +package sk.ainet.data.common + +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream +import java.io.File +import java.util.zip.GZIPInputStream + +internal class JvmDatasetSourceReader( + cacheDir: String, + useCache: Boolean +) { + private val resolver = JvmDataSourceResolver(File(cacheDir, "sources")) + private val cachePolicy = if (useCache) CachePolicy.Use else CachePolicy.Refresh + + suspend fun read(uri: String): 
ByteArray { + val artifact = resolver.resolve( + DataSourceRequest( + uri = uri, + cachePolicy = cachePolicy + ) + ) + return artifact.readBytes() + } + + suspend fun readGzipDecoded(uri: String): ByteArray = read(uri).gunzipIfNeeded() +} + +internal fun ByteArray.gunzip(): ByteArray { + return GZIPInputStream(ByteArrayInputStream(this)).use { it.readBytes() } +} + +internal fun ByteArray.gunzipIfNeeded(): ByteArray { + return if (isGzip()) gunzip() else this +} + +private fun ByteArray.isGzip(): Boolean { + return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() +} diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt index 332dfeb2..9b100db9 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt @@ -1,13 +1,6 @@ package sk.ainet.data.fashionmnist -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import sk.ainet.data.source.CachePolicy -import sk.ainet.data.source.DataSourceRequest -import sk.ainet.data.source.JvmDataSourceResolver -import java.io.ByteArrayInputStream -import java.util.zip.GZIPInputStream -import java.io.File +import sk.ainet.data.common.JvmDatasetSourceReader /** * JVM implementation of the Fashion-MNIST loader. @@ -15,7 +8,7 @@ import java.io.File * @property config The configuration for the Fashion-MNIST loader. */ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) { - private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) + private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) /** * Resolves, caches, and decompresses a file when needed. 
@@ -24,23 +17,8 @@ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMN * @param filename The name of the file to save. * @return The bytes of the decompressed file. */ - override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val artifact = resolver.resolve( - DataSourceRequest( - uri = url, - cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh - ) - ) - return@withContext maybeGunzip(artifact.readBytes()) - } - - private fun maybeGunzip(bytes: ByteArray): ByteArray { - if (!bytes.isGzip()) return bytes - return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } - } - - private fun ByteArray.isGzip(): Boolean { - return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() + override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray { + return sources.readGzipDecoded(url) } public companion object { diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt index 330197d2..e5466b4f 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt @@ -1,13 +1,6 @@ package sk.ainet.data.mnist -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import sk.ainet.data.source.CachePolicy -import sk.ainet.data.source.DataSourceRequest -import sk.ainet.data.source.JvmDataSourceResolver -import java.io.ByteArrayInputStream -import java.util.zip.GZIPInputStream -import java.io.File +import sk.ainet.data.common.JvmDatasetSourceReader /** * JVM implementation of the MNIST loader. @@ -15,7 +8,7 @@ import java.io.File * @property config The configuration for the MNIST loader. 
*/ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) { - private val resolver = JvmDataSourceResolver(File(config.cacheDir, "sources")) + private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) /** * Resolves, caches, and decompresses a file when needed. @@ -24,23 +17,8 @@ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi * @param filename The name of the file to save. * @return The bytes of the decompressed file. */ - override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val artifact = resolver.resolve( - DataSourceRequest( - uri = url, - cachePolicy = if (config.useCache) CachePolicy.Use else CachePolicy.Refresh - ) - ) - return@withContext maybeGunzip(artifact.readBytes()) - } - - private fun maybeGunzip(bytes: ByteArray): ByteArray { - if (!bytes.isGzip()) return bytes - return GZIPInputStream(ByteArrayInputStream(bytes)).use { it.readBytes() } - } - - private fun ByteArray.isGzip(): Boolean { - return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() + override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray { + return sources.readGzipDecoded(url) } public companion object { diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts index f8b4dbc0..7e4ee92d 100644 --- a/skainet-data/skainet-data-source/build.gradle.kts +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -22,6 +22,7 @@ kotlin { commonTest.dependencies { implementation(libs.kotlin.test) + implementation(libs.kotlinx.coroutines.test) } jvmMain.dependencies { @@ -31,8 +32,5 @@ kotlin { implementation(libs.kotlinx.coroutines.core.jvm) } - jvmTest.dependencies { - implementation(libs.kotlinx.coroutines.test) - } } } diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt 
b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt new file mode 100644 index 00000000..4524221e --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt @@ -0,0 +1,138 @@ +package sk.ainet.data.source + +/** + * Fetches a remote URI into memory. Kept injectable so tests and applications + * can provide their own HTTP stack or policy layer. + */ +public fun interface RemoteDataSourceFetcher { + public suspend fun fetch(uri: String, headers: Map): ByteArray +} + +/** + * Adds platform or application-specific headers to a resolved remote request. + */ +public fun interface DataSourceHeaderProvider { + public fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map +} + +/** + * Computes checksums for integrity verification without tying resolver policy + * to a concrete platform crypto API. + */ +public fun interface DataSourceChecksum { + public fun sha256Hex(bytes: ByteArray): String +} + +/** + * Platform storage adapter used by [DefaultDataSourceResolver]. + */ +public interface DataSourceByteStore { + public suspend fun readLocal(path: String): DataSourceStoredArtifact? + public suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? + public suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact +} + +/** + * A platform materialized artifact used by the common resolver core. + */ +public class DataSourceStoredArtifact( + public val localPath: String?, + public val sizeBytes: Long?, + private val byteReader: suspend () -> ByteArray +) { + public suspend fun readBytes(): ByteArray = byteReader() +} + +/** + * Platform-neutral resolver implementation for local files, HTTP(S), and + * Hugging Face source URIs. Storage, network, auth, and checksum details are + * injected so this policy can be reused by each KMP target. 
+ */ +public class DefaultDataSourceResolver( + private val store: DataSourceByteStore, + private val fetcher: RemoteDataSourceFetcher, + private val checksum: DataSourceChecksum, + private val headerProvider: DataSourceHeaderProvider = DataSourceHeaderProvider { request, _ -> + request.headers + } +) : DataSourceResolver { + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact { + val parsed = DataSourceUriParser.parse(request.uri) + return when (parsed.provider) { + DataSourceProvider.File -> resolveFile(request, parsed) + DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed) + } + } + + private suspend fun resolveFile( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") + val stored = store.readLocal(path) + ?: throw DataSourceException("Data source file not found: $path") + request.expectedSha256?.let { verifySha256(stored.readBytes(), it, request.uri) } + return stored.toArtifact(request, parsed, cacheHit = true) + } + + private suspend fun resolveRemote( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline + if (canUseCache) { + val cached = store.readCache(parsed.cacheKey) + if (cached != null) { + request.expectedSha256?.let { verifySha256(cached.readBytes(), it, request.uri) } + return cached.toArtifact(request, parsed, cacheHit = true) + } + } + + if (request.cachePolicy == CachePolicy.Offline) { + throw DataSourceException("No cached artifact available for offline source: ${request.uri}") + } + + val bytes = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed)) + request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } + + if (request.cachePolicy == CachePolicy.Bypass) { + return 
DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = null, + sizeBytes = bytes.size.toLong(), + cacheHit = false, + byteReader = { bytes } + ) + } + + val stored = store.writeCache(parsed.cacheKey, bytes) + return stored.toArtifact(request, parsed, cacheHit = false) + } + + private suspend fun DataSourceStoredArtifact.toArtifact( + request: DataSourceRequest, + parsed: ParsedDataSourceUri, + cacheHit: Boolean + ): DataSourceArtifact { + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = localPath, + sizeBytes = sizeBytes, + cacheHit = cacheHit, + byteReader = { readBytes() } + ) + } + + private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { + val actual = checksum.sha256Hex(bytes) + if (!actual.equals(expected, ignoreCase = true)) { + throw DataSourceException( + "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" + ) + } + } +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt new file mode 100644 index 00000000..1b864465 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt @@ -0,0 +1,164 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class DefaultDataSourceResolverTest { + @Test + fun resolvesLocalArtifactsThroughStore() = runTest { + val store = MemoryDataSourceByteStore( + localArtifacts = mapOf("/data/sample.bin" to "local".encodeToByteArray()) + ) + val fetcher = RecordingFetcher("remote".encodeToByteArray()) + val 
resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + + val artifact = resolver.resolve(DataSourceRequest("/data/sample.bin")) + + assertEquals("/data/sample.bin", artifact.localPath) + assertTrue(artifact.cacheHit) + assertEquals(0, fetcher.calls) + assertContentEquals("local".encodeToByteArray(), artifact.readBytes()) + } + + @Test + fun cachesRemoteArtifactsAndReusesThem() = runTest { + val store = MemoryDataSourceByteStore() + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + val request = DataSourceRequest( + uri = "https://example.test/data.bin", + expectedSha256 = "sha:payload" + ) + + val first = resolver.resolve(request) + val second = resolver.resolve(request) + + assertFalse(first.cacheHit) + assertTrue(second.cacheHit) + assertEquals(1, fetcher.calls) + assertEquals(1, store.cacheWrites) + assertContentEquals("payload".encodeToByteArray(), second.readBytes()) + } + + @Test + fun bypassSkipsPersistentCache() = runTest { + val store = MemoryDataSourceByteStore() + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + + val artifact = resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + cachePolicy = CachePolicy.Bypass + ) + ) + + assertEquals(null, artifact.localPath) + assertFalse(artifact.cacheHit) + assertEquals(1, fetcher.calls) + assertEquals(0, store.cacheWrites) + } + + @Test + fun verifiesChecksumsInCommonCore() = runTest { + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = RecordingFetcher("payload".encodeToByteArray()), + checksum = TestChecksum + ) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + expectedSha256 = "sha:other" + ) + ) + } + } + + @Test + fun forwardsProviderHeadersToFetcher() = runTest { + val fetcher = 
RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum, + headerProvider = DataSourceHeaderProvider { request, parsedUri -> + request.headers + ("X-SKaiNET-Provider" to parsedUri.provider.name) + } + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/org/repo@main/file.bin", + headers = mapOf("Accept" to "application/octet-stream") + ) + ) + + assertEquals( + mapOf( + "Accept" to "application/octet-stream", + "X-SKaiNET-Provider" to "HuggingFace" + ), + fetcher.lastHeaders + ) + } +} + +private class MemoryDataSourceByteStore( + private val localArtifacts: Map = emptyMap() +) : DataSourceByteStore { + private val cacheArtifacts = mutableMapOf() + + var cacheWrites: Int = 0 + private set + + override suspend fun readLocal(path: String): DataSourceStoredArtifact? { + return localArtifacts[path]?.storedAt(path) + } + + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? 
{ + return cacheArtifacts[cacheKey]?.storedAt("/cache/$cacheKey") + } + + override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { + cacheWrites++ + cacheArtifacts[cacheKey] = bytes + return bytes.storedAt("/cache/$cacheKey") + } + + private fun ByteArray.storedAt(path: String): DataSourceStoredArtifact { + val bytes = copyOf() + return DataSourceStoredArtifact( + localPath = path, + sizeBytes = bytes.size.toLong(), + byteReader = { bytes.copyOf() } + ) + } +} + +private class RecordingFetcher( + private val bytes: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + var lastHeaders: Map = emptyMap() + private set + + override suspend fun fetch(uri: String, headers: Map): ByteArray { + calls++ + lastHeaders = headers + return bytes.copyOf() + } +} + +private object TestChecksum : DataSourceChecksum { + override fun sha256Hex(bytes: ByteArray): String = "sha:${bytes.decodeToString()}" +} diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt index 7cba502a..0dec564d 100644 --- a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -11,14 +11,6 @@ import kotlinx.coroutines.withContext import java.io.File import java.security.MessageDigest -/** - * Fetches a remote URI into memory. Kept injectable so tests and applications - * can provide their own HTTP stack or policy layer. - */ -public fun interface RemoteDataSourceFetcher { - public suspend fun fetch(uri: String, headers: Map): ByteArray -} - /** * Ktor/CIO-backed remote fetcher for JVM data artifacts. */ @@ -47,99 +39,70 @@ public class KtorRemoteDataSourceFetcher( * JVM resolver for local files and cached remote artifacts. 
*/ public class JvmDataSourceResolver( - private val cacheDir: File = defaultCacheDir(), - private val fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() + cacheDir: File = defaultCacheDir(), + fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() ) : DataSourceResolver { - override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { - val parsed = DataSourceUriParser.parse(request.uri) - when (parsed.provider) { - DataSourceProvider.File -> resolveFile(request, parsed) - DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed) - } - } + private val delegate = DefaultDataSourceResolver( + store = JvmFileDataSourceByteStore(cacheDir), + fetcher = fetcher, + checksum = JvmSha256DataSourceChecksum, + headerProvider = JvmHuggingFaceHeaderProvider + ) - private fun resolveFile( - request: DataSourceRequest, - parsed: ParsedDataSourceUri - ): DataSourceArtifact { - val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") - val file = File(path) - require(file.exists()) { "Data source file not found: ${file.absolutePath}" } - require(file.isFile) { "Data source path is not a file: ${file.absolutePath}" } - request.expectedSha256?.let { verifySha256(file.readBytes(), it, request.uri) } - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = file.absolutePath, - sizeBytes = file.length(), - cacheHit = true, - byteReader = { file.readBytes() } - ) + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { + delegate.resolve(request) } - private suspend fun resolveRemote( - request: DataSourceRequest, - parsed: ParsedDataSourceUri - ): DataSourceArtifact { - val target = File(cacheDir, parsed.cacheKey) - val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline - if 
(canUseCache && target.exists() && target.isFile) { - request.expectedSha256?.let { verifySha256(target.readBytes(), it, request.uri) } - return cachedArtifact(request, parsed, target, cacheHit = true) + public companion object { + public fun defaultCacheDir(): File { + val userHome = System.getProperty("user.home")?.takeIf { it.isNotBlank() } + val base = userHome ?: System.getProperty("java.io.tmpdir") + return File(base, ".cache/skainet/data") } + } +} - if (request.cachePolicy == CachePolicy.Offline) { - throw DataSourceException("No cached artifact available for offline source: ${request.uri}") +internal class JvmFileDataSourceByteStore( + private val cacheDir: File +) : DataSourceByteStore { + override suspend fun readLocal(path: String): DataSourceStoredArtifact? { + val file = File(path) + if (!file.exists()) return null + if (!file.isFile) { + throw DataSourceException("Data source path is not a file: ${file.absolutePath}") } + return file.toStoredArtifact() + } - val bytes = fetcher.fetch(parsed.transportUri, requestHeaders(request, parsed)) - request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } - - if (request.cachePolicy == CachePolicy.Bypass) { - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = null, - sizeBytes = bytes.size.toLong(), - cacheHit = false, - byteReader = { bytes } - ) - } + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? 
{ + val target = File(cacheDir, cacheKey) + return if (target.exists() && target.isFile) target.toStoredArtifact() else null + } + override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { cacheDir.mkdirs() - val temp = File(cacheDir, "${parsed.cacheKey}.tmp") + val target = File(cacheDir, cacheKey) + val temp = File(cacheDir, "$cacheKey.tmp") temp.writeBytes(bytes) if (!temp.renameTo(target)) { temp.copyTo(target, overwrite = true) temp.delete() } - return cachedArtifact(request, parsed, target, cacheHit = false) + return target.toStoredArtifact() } - private fun cachedArtifact( - request: DataSourceRequest, - parsed: ParsedDataSourceUri, - target: File, - cacheHit: Boolean - ): DataSourceArtifact { - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = target.absolutePath, - sizeBytes = target.length(), - cacheHit = cacheHit, - byteReader = { target.readBytes() } + private fun File.toStoredArtifact(): DataSourceStoredArtifact { + return DataSourceStoredArtifact( + localPath = absolutePath, + sizeBytes = length(), + byteReader = { readBytes() } ) } +} - private fun requestHeaders( - request: DataSourceRequest, - parsed: ParsedDataSourceUri - ): Map { - if (parsed.provider != DataSourceProvider.HuggingFace) return request.headers +internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { + override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { + if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers if (request.headers.keys.any { it.equals("Authorization", ignoreCase = true) }) return request.headers val token = System.getenv("HF_TOKEN") ?.takeIf { it.isNotBlank() } @@ -147,23 +110,12 @@ public class JvmDataSourceResolver( ?: return request.headers return request.headers + ("Authorization" to "Bearer $token") } +} - private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { - 
val actual = MessageDigest.getInstance("SHA-256") +internal object JvmSha256DataSourceChecksum : DataSourceChecksum { + override fun sha256Hex(bytes: ByteArray): String { + return MessageDigest.getInstance("SHA-256") .digest(bytes) .joinToString("") { byte -> "%02x".format(byte) } - if (!actual.equals(expected, ignoreCase = true)) { - throw DataSourceException( - "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" - ) - } - } - - public companion object { - public fun defaultCacheDir(): File { - val userHome = System.getProperty("user.home")?.takeIf { it.isNotBlank() } - val base = userHome ?: System.getProperty("java.io.tmpdir") - return File(base, ".cache/skainet/data") - } } } From ed357d7dfc000e0ea99846d20a52733c3a70ee21 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 15:00:54 +0200 Subject: [PATCH 6/7] data: stream source artifacts with kotlinx-io --- .../data-sources-getting-started.adoc | 14 +- .../skainet-data-source/build.gradle.kts | 1 + .../sk/ainet/data/source/DataSourceModels.kt | 37 ++- .../data/source/DefaultDataSourceResolver.kt | 269 ++++++++++++++++-- .../source/DefaultDataSourceResolverTest.kt | 56 +++- .../data/source/JvmDataSourceResolver.kt | 72 ++--- .../data/source/JvmDataSourceResolverTest.kt | 10 +- 7 files changed, 361 insertions(+), 98 deletions(-) diff --git a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc index b4001501..34eb72a7 100644 --- a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc +++ b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc @@ -40,9 +40,9 @@ dependencies { === Resolve one artifact `JvmDataSourceResolver` materializes remote artifacts into a cache and returns -a `DataSourceArtifact` that can be read as bytes. Public Hugging Face files do -not need credentials. 
Private files can use an `Authorization` header, or the -JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the +a `DataSourceArtifact` that opens a `kotlinx.io.Source`. Public Hugging Face +files do not need credentials. Private files can use an `Authorization` header, +or the JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the environment when the URI provider is Hugging Face. [source,kotlin] @@ -60,6 +60,14 @@ val artifact = resolver.resolve( println(artifact.filename) println(artifact.localPath) +val source = artifact.openSource() +try { + // Pass the source to a parser/loader for model-sized artifacts. +} finally { + source.close() +} + +// Convenience for small sidecars and tests. val bytes = artifact.readBytes() ---- diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts index 7e4ee92d..7d7c3065 100644 --- a/skainet-data/skainet-data-source/build.gradle.kts +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -18,6 +18,7 @@ kotlin { sourceSets { commonMain.dependencies { implementation(libs.kotlinx.coroutines) + implementation(libs.kotlinx.io.core) } commonTest.dependencies { diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt index 71eacf32..0e34245d 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt @@ -1,5 +1,9 @@ package sk.ainet.data.source +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.readByteArray + /** * Cache behavior requested by a caller resolving a data artifact. 
*/ @@ -80,9 +84,38 @@ public class DataSourceArtifact( public val localPath: String?, public val sizeBytes: Long?, public val cacheHit: Boolean, - private val byteReader: suspend () -> ByteArray + private val sourceOpener: suspend () -> Source ) { - public suspend fun readBytes(): ByteArray = byteReader() + /** + * Opens a fresh source for this artifact. Callers own and must close it. + */ + public suspend fun openSource(): Source = sourceOpener() + + /** + * Convenience for small artifacts. Prefer [openSource] or [copyTo] for + * model-scale data. + */ + public suspend fun readBytes(): ByteArray { + val source = openSource() + return try { + source.readByteArray() + } finally { + source.close() + } + } + + /** + * Streams this artifact into [sink]. The source is closed after copying; + * [sink] is left open for the caller. + */ + public suspend fun copyTo(sink: Sink): Long { + val source = openSource() + return try { + source.transferTo(sink) + } finally { + source.close() + } + } } /** diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt index 4524221e..b7ee217f 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt @@ -1,11 +1,38 @@ package sk.ainet.data.source +import kotlinx.io.Buffer +import kotlinx.io.RawSource +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.buffered +import kotlinx.io.files.FileSystem +import kotlinx.io.files.Path +import kotlinx.io.files.SystemFileSystem +import kotlinx.io.readByteArray + +/** + * Remote response body exposed as a one-shot [Source]. + */ +public class DataSourceRemoteContent( + public val source: Source, + public val sizeBytes: Long? 
= null +) { + public companion object { + public fun fromBytes(bytes: ByteArray): DataSourceRemoteContent { + return DataSourceRemoteContent( + source = bytes.toDataSourceSource(), + sizeBytes = bytes.size.toLong() + ) + } + } +} + /** - * Fetches a remote URI into memory. Kept injectable so tests and applications + * Fetches a remote URI as a stream. Kept injectable so tests and applications * can provide their own HTTP stack or policy layer. */ public fun interface RemoteDataSourceFetcher { - public suspend fun fetch(uri: String, headers: Map): ByteArray + public suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent } /** @@ -19,28 +46,153 @@ public fun interface DataSourceHeaderProvider { * Computes checksums for integrity verification without tying resolver policy * to a concrete platform crypto API. */ -public fun interface DataSourceChecksum { - public fun sha256Hex(bytes: ByteArray): String +public interface DataSourceChecksum { + public fun newSha256(): DataSourceHash +} + +/** + * Incremental hash state used while streaming artifact bytes. + */ +public interface DataSourceHash { + public fun update(bytes: ByteArray, startIndex: Int = 0, endIndex: Int = bytes.size) + public fun hex(): String } /** * Platform storage adapter used by [DefaultDataSourceResolver]. */ -public interface DataSourceByteStore { +public interface DataSourceArtifactStore { public suspend fun readLocal(path: String): DataSourceStoredArtifact? public suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? - public suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact + public suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long? = null, + validate: suspend (DataSourceStoredArtifact) -> Unit = {} + ): DataSourceStoredArtifact } /** - * A platform materialized artifact used by the common resolver core. + * A materialized artifact used by the common resolver core. 
*/ public class DataSourceStoredArtifact( public val localPath: String?, public val sizeBytes: Long?, - private val byteReader: suspend () -> ByteArray + private val sourceOpener: suspend () -> Source ) { - public suspend fun readBytes(): ByteArray = byteReader() + public suspend fun openSource(): Source = sourceOpener() + + public suspend fun readBytes(): ByteArray { + val source = openSource() + return try { + source.readByteArray() + } finally { + source.close() + } + } + + public suspend fun copyTo(sink: Sink): Long { + val source = openSource() + return try { + source.transferTo(sink) + } finally { + source.close() + } + } + + public companion object { + public fun inMemory( + bytes: ByteArray, + localPath: String? = null + ): DataSourceStoredArtifact { + val owned = bytes.copyOf() + return DataSourceStoredArtifact( + localPath = localPath, + sizeBytes = owned.size.toLong(), + sourceOpener = { owned.toDataSourceSource() } + ) + } + + public fun inMemoryFrom( + source: Source, + localPath: String? = null, + sizeBytes: Long? = null + ): DataSourceStoredArtifact { + val buffer = Buffer() + val copied = source.transferTo(buffer) + return DataSourceStoredArtifact( + localPath = localPath, + sizeBytes = sizeBytes ?: copied, + sourceOpener = { buffer.copy() } + ) + } + } +} + +/** + * Filesystem-backed artifact store built on kotlinx-io so the cache policy + * remains reusable across KMP targets that expose [SystemFileSystem]. + */ +public class FileSystemDataSourceArtifactStore( + private val cacheDir: Path, + private val fileSystem: FileSystem = SystemFileSystem +) : DataSourceArtifactStore { + override suspend fun readLocal(path: String): DataSourceStoredArtifact? 
{ + val localPath = Path(path) + val metadata = fileSystem.metadataOrNull(localPath) ?: return null + if (!metadata.isRegularFile) { + throw DataSourceException("Data source path is not a file: $path") + } + val resolved = fileSystem.resolve(localPath) + return resolved.toStoredArtifact(metadata.size) + } + + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? { + val target = Path(cacheDir, cacheKey) + val metadata = fileSystem.metadataOrNull(target) ?: return null + return if (metadata.isRegularFile) target.toStoredArtifact(metadata.size) else null + } + + override suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long?, + validate: suspend (DataSourceStoredArtifact) -> Unit + ): DataSourceStoredArtifact { + fileSystem.createDirectories(cacheDir) + val target = Path(cacheDir, cacheKey) + val temp = Path(cacheDir, "$cacheKey.tmp") + + val sink = fileSystem.sink(temp).buffered() + try { + source.transferTo(sink) + sink.flush() + } finally { + sink.close() + } + + val tempMetadata = fileSystem.metadataOrNull(temp) + val tempArtifact = temp.toStoredArtifact(tempMetadata?.size ?: sizeBytes) + try { + validate(tempArtifact) + } catch (throwable: Throwable) { + fileSystem.delete(temp, mustExist = false) + throw throwable + } + + fileSystem.atomicMove(temp, target) + val metadata = fileSystem.metadataOrNull(target) + return target.toStoredArtifact(metadata?.size ?: sizeBytes) + } + + private fun Path.toStoredArtifact(sizeBytes: Long?): DataSourceStoredArtifact { + val path = this + return DataSourceStoredArtifact( + localPath = path.toString(), + sizeBytes = sizeBytes, + sourceOpener = { fileSystem.source(path).buffered() } + ) + } } /** @@ -49,7 +201,7 @@ public class DataSourceStoredArtifact( * injected so this policy can be reused by each KMP target. 
*/ public class DefaultDataSourceResolver( - private val store: DataSourceByteStore, + private val store: DataSourceArtifactStore, private val fetcher: RemoteDataSourceFetcher, private val checksum: DataSourceChecksum, private val headerProvider: DataSourceHeaderProvider = DataSourceHeaderProvider { request, _ -> @@ -71,7 +223,7 @@ public class DefaultDataSourceResolver( val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") val stored = store.readLocal(path) ?: throw DataSourceException("Data source file not found: $path") - request.expectedSha256?.let { verifySha256(stored.readBytes(), it, request.uri) } + request.expectedSha256?.let { verifySha256(stored, it, request.uri) } return stored.toArtifact(request, parsed, cacheHit = true) } @@ -83,7 +235,7 @@ public class DefaultDataSourceResolver( if (canUseCache) { val cached = store.readCache(parsed.cacheKey) if (cached != null) { - request.expectedSha256?.let { verifySha256(cached.readBytes(), it, request.uri) } + request.expectedSha256?.let { verifySha256(cached, it, request.uri) } return cached.toArtifact(request, parsed, cacheHit = true) } } @@ -92,22 +244,35 @@ public class DefaultDataSourceResolver( throw DataSourceException("No cached artifact available for offline source: ${request.uri}") } - val bytes = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed)) - request.expectedSha256?.let { verifySha256(bytes, it, request.uri) } + val remote = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed)) if (request.cachePolicy == CachePolicy.Bypass) { - return DataSourceArtifact( - request = request, - parsedUri = parsed, - filename = parsed.filename, - localPath = null, - sizeBytes = bytes.size.toLong(), - cacheHit = false, - byteReader = { bytes } - ) + val stored = try { + DataSourceStoredArtifact.inMemoryFrom(remote.source, sizeBytes = remote.sizeBytes) + } finally { + remote.source.close() + } + 
request.expectedSha256?.let { verifySha256(stored, it, request.uri) } + return stored.toArtifact(request, parsed, cacheHit = false) } - val stored = store.writeCache(parsed.cacheKey, bytes) + val expectedSha256 = request.expectedSha256 + val hash = expectedSha256?.let { checksum.newSha256() } + val source = hash?.let { HashingRawSource(remote.source, it).buffered() } ?: remote.source + val stored = try { + store.writeCache( + cacheKey = parsed.cacheKey, + source = source, + sizeBytes = remote.sizeBytes, + validate = { + if (expectedSha256 != null && hash != null) { + verifySha256Hex(hash.hex(), expectedSha256, request.uri) + } + } + ) + } finally { + source.close() + } return stored.toArtifact(request, parsed, cacheHit = false) } @@ -123,16 +288,66 @@ public class DefaultDataSourceResolver( localPath = localPath, sizeBytes = sizeBytes, cacheHit = cacheHit, - byteReader = { readBytes() } + sourceOpener = { openSource() } ) } - private fun verifySha256(bytes: ByteArray, expected: String, uri: String) { - val actual = checksum.sha256Hex(bytes) + private suspend fun verifySha256(artifact: DataSourceStoredArtifact, expected: String, uri: String) { + val actual = artifact.sha256Hex() + verifySha256Hex(actual, expected, uri) + } + + private fun verifySha256Hex(actual: String, expected: String, uri: String) { if (!actual.equals(expected, ignoreCase = true)) { throw DataSourceException( "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" ) } } + + private suspend fun DataSourceStoredArtifact.sha256Hex(): String { + val hash = checksum.newSha256() + val buffer = ByteArray(STREAM_BUFFER_SIZE) + val source = openSource() + try { + while (true) { + val read = source.readAtMostTo(buffer) + if (read == -1) break + hash.update(buffer, endIndex = read) + } + } finally { + source.close() + } + return hash.hex() + } + + private companion object { + private const val STREAM_BUFFER_SIZE = 8 * 1024 + } +} + +private class HashingRawSource( + private val source: 
Source, + private val hash: DataSourceHash +) : RawSource { + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long { + val start = sink.size + val read = source.readAtMostTo(sink, byteCount) + if (read > 0) { + val copied = Buffer() + sink.copyTo(copied, startIndex = start, endIndex = start + read) + hash.update(copied.readByteArray()) + } + return read + } + + override fun close() { + source.close() + } +} + +private fun ByteArray.toDataSourceSource(): Source { + val buffer = Buffer() + buffer.write(this) + return buffer } diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt index 1b864465..01e4924e 100644 --- a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt @@ -1,6 +1,9 @@ package sk.ainet.data.source import kotlinx.coroutines.test.runTest +import kotlinx.io.Buffer +import kotlinx.io.Source +import kotlinx.io.readByteArray import kotlin.test.Test import kotlin.test.assertContentEquals import kotlin.test.assertEquals @@ -23,6 +26,10 @@ class DefaultDataSourceResolverTest { assertTrue(artifact.cacheHit) assertEquals(0, fetcher.calls) assertContentEquals("local".encodeToByteArray(), artifact.readBytes()) + + val copied = Buffer() + assertEquals(5, artifact.copyTo(copied)) + assertContentEquals("local".encodeToByteArray(), copied.readByteArray()) } @Test @@ -66,8 +73,9 @@ class DefaultDataSourceResolverTest { @Test fun verifiesChecksumsInCommonCore() = runTest { + val store = MemoryDataSourceByteStore() val resolver = DefaultDataSourceResolver( - store = MemoryDataSourceByteStore(), + store = store, fetcher = RecordingFetcher("payload".encodeToByteArray()), checksum = TestChecksum ) @@ -80,6 +88,7 @@ class 
DefaultDataSourceResolverTest { ) ) } + assertEquals(0, store.cacheWrites) } @Test @@ -113,8 +122,8 @@ class DefaultDataSourceResolverTest { private class MemoryDataSourceByteStore( private val localArtifacts: Map = emptyMap() -) : DataSourceByteStore { - private val cacheArtifacts = mutableMapOf() +) : DataSourceArtifactStore { + private val cacheArtifacts = mutableMapOf() var cacheWrites: Int = 0 private set @@ -124,22 +133,29 @@ private class MemoryDataSourceByteStore( } override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? { - return cacheArtifacts[cacheKey]?.storedAt("/cache/$cacheKey") + return cacheArtifacts[cacheKey] } - override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { + override suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long?, + validate: suspend (DataSourceStoredArtifact) -> Unit + ): DataSourceStoredArtifact { + val stored = DataSourceStoredArtifact.inMemoryFrom( + source = source, + localPath = "/cache/$cacheKey", + sizeBytes = sizeBytes + ) + validate(stored) cacheWrites++ - cacheArtifacts[cacheKey] = bytes - return bytes.storedAt("/cache/$cacheKey") + cacheArtifacts[cacheKey] = stored + return stored } private fun ByteArray.storedAt(path: String): DataSourceStoredArtifact { val bytes = copyOf() - return DataSourceStoredArtifact( - localPath = path, - sizeBytes = bytes.size.toLong(), - byteReader = { bytes.copyOf() } - ) + return DataSourceStoredArtifact.inMemory(bytes, localPath = path) } } @@ -152,13 +168,23 @@ private class RecordingFetcher( var lastHeaders: Map = emptyMap() private set - override suspend fun fetch(uri: String, headers: Map): ByteArray { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { calls++ lastHeaders = headers - return bytes.copyOf() + return DataSourceRemoteContent.fromBytes(bytes.copyOf()) } } private object TestChecksum : DataSourceChecksum { - override fun sha256Hex(bytes: ByteArray): 
String = "sha:${bytes.decodeToString()}" + override fun newSha256(): DataSourceHash = TestHash() +} + +private class TestHash : DataSourceHash { + private val text = StringBuilder() + + override fun update(bytes: ByteArray, startIndex: Int, endIndex: Int) { + text.append(bytes.copyOfRange(startIndex, endIndex).decodeToString()) + } + + override fun hex(): String = "sha:$text" } diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt index 0dec564d..2a148bf4 100644 --- a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -1,11 +1,15 @@ package sk.ainet.data.source import io.ktor.client.HttpClient -import io.ktor.client.call.body import io.ktor.client.engine.cio.CIO import io.ktor.client.plugins.HttpTimeout import io.ktor.client.request.get import io.ktor.client.request.header +import io.ktor.client.statement.bodyAsChannel +import io.ktor.http.HttpHeaders +import io.ktor.utils.io.asSource +import kotlinx.io.buffered +import kotlinx.io.files.Path import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext import java.io.File @@ -24,10 +28,14 @@ public class KtorRemoteDataSourceFetcher( } } ) : RemoteDataSourceFetcher, AutoCloseable { - override suspend fun fetch(uri: String, headers: Map): ByteArray { - return client.get(uri) { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { + val response = client.get(uri) { headers.forEach { (name, value) -> header(name, value) } - }.body() + } + return DataSourceRemoteContent( + source = response.bodyAsChannel().asSource().buffered(), + sizeBytes = response.headers[HttpHeaders.ContentLength]?.toLongOrNull() + ) } override fun close() { @@ -43,7 +51,7 @@ public class 
JvmDataSourceResolver( fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() ) : DataSourceResolver { private val delegate = DefaultDataSourceResolver( - store = JvmFileDataSourceByteStore(cacheDir), + store = FileSystemDataSourceArtifactStore(Path(cacheDir.absolutePath)), fetcher = fetcher, checksum = JvmSha256DataSourceChecksum, headerProvider = JvmHuggingFaceHeaderProvider @@ -62,44 +70,6 @@ public class JvmDataSourceResolver( } } -internal class JvmFileDataSourceByteStore( - private val cacheDir: File -) : DataSourceByteStore { - override suspend fun readLocal(path: String): DataSourceStoredArtifact? { - val file = File(path) - if (!file.exists()) return null - if (!file.isFile) { - throw DataSourceException("Data source path is not a file: ${file.absolutePath}") - } - return file.toStoredArtifact() - } - - override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? { - val target = File(cacheDir, cacheKey) - return if (target.exists() && target.isFile) target.toStoredArtifact() else null - } - - override suspend fun writeCache(cacheKey: String, bytes: ByteArray): DataSourceStoredArtifact { - cacheDir.mkdirs() - val target = File(cacheDir, cacheKey) - val temp = File(cacheDir, "$cacheKey.tmp") - temp.writeBytes(bytes) - if (!temp.renameTo(target)) { - temp.copyTo(target, overwrite = true) - temp.delete() - } - return target.toStoredArtifact() - } - - private fun File.toStoredArtifact(): DataSourceStoredArtifact { - return DataSourceStoredArtifact( - localPath = absolutePath, - sizeBytes = length(), - byteReader = { readBytes() } - ) - } -} - internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers @@ -113,9 +83,19 @@ internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { } internal object JvmSha256DataSourceChecksum : 
DataSourceChecksum { - override fun sha256Hex(bytes: ByteArray): String { - return MessageDigest.getInstance("SHA-256") - .digest(bytes) + override fun newSha256(): DataSourceHash = JvmSha256DataSourceHash() +} + +private class JvmSha256DataSourceHash : DataSourceHash { + private val digest = MessageDigest.getInstance("SHA-256") + + override fun update(bytes: ByteArray, startIndex: Int, endIndex: Int) { + digest.update(bytes, startIndex, endIndex - startIndex) + } + + override fun hex(): String { + return digest + .digest() .joinToString("") { byte -> "%02x".format(byte) } } } diff --git a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt index 5b44523e..f0102d38 100644 --- a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt @@ -22,7 +22,7 @@ class JvmDataSourceResolverTest { val artifact = resolver.resolve(DataSourceRequest(file.toURI().toString())) assertEquals("sample.txt", artifact.filename) - assertEquals(file.absolutePath, artifact.localPath) + assertEquals(file.canonicalPath, artifact.localPath) assertTrue(artifact.cacheHit) assertContentEquals("hello".encodeToByteArray(), artifact.readBytes()) } finally { @@ -142,9 +142,9 @@ private class FakeFetcher( var calls: Int = 0 private set - override suspend fun fetch(uri: String, headers: Map): ByteArray { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { calls++ - return bytes + return DataSourceRemoteContent.fromBytes(bytes) } } @@ -154,9 +154,9 @@ private class QueueFetcher( var calls: Int = 0 private set - override suspend fun fetch(uri: String, headers: Map): ByteArray { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { val index = 
calls.coerceAtMost(responses.lastIndex) calls++ - return responses[index] + return DataSourceRemoteContent.fromBytes(responses[index]) } } From c6c9be3854f03fe475b36d156a44ed70034b06d5 Mon Sep 17 00:00:00 2001 From: Michal Harakal Date: Mon, 29 Jun 2026 19:31:58 +0200 Subject: [PATCH 7/7] data: parameterize Hugging Face auth --- .../data-sources-getting-started.adoc | 38 +++++++++-- .../sk/ainet/data/cifar10/CIFAR10Data.kt | 5 +- .../common/DatasetHuggingFaceTokenProvider.kt | 9 +++ .../data/fashionmnist/FashionMNISTData.kt | 5 +- .../kotlin/sk/ainet/data/mnist/MNISTData.kt | 5 +- .../sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt | 7 +- .../data/common/JvmDatasetSourceReader.kt | 13 +++- .../fashionmnist/FashionMNISTLoaderJvm.kt | 7 +- .../sk/ainet/data/mnist/MNISTLoaderJvm.kt | 7 +- .../sk/ainet/data/source/DataSourceModels.kt | 30 ++++++++- .../data/source/DefaultDataSourceResolver.kt | 34 +++++++++- .../source/DefaultDataSourceResolverTest.kt | 67 +++++++++++++++++++ .../data/source/JvmDataSourceResolver.kt | 30 ++++++--- .../data/source/JvmDataSourceResolverTest.kt | 49 ++++++++++++++ 14 files changed, 277 insertions(+), 29 deletions(-) create mode 100644 skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt diff --git a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc index 34eb72a7..4557e875 100644 --- a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc +++ b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc @@ -41,16 +41,20 @@ dependencies { `JvmDataSourceResolver` materializes remote artifacts into a cache and returns a `DataSourceArtifact` that opens a `kotlinx.io.Source`. Public Hugging Face -files do not need credentials. 
Private files can use an `Authorization` header, -or the JVM resolver will read `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` from the -environment when the URI provider is Hugging Face. +files do not need credentials. Private files should pass an explicit +`DataSourceAuthToken` on the request or resolver. Existing `Authorization` +headers still take precedence. On JVM, the resolver can also read `HF_TOKEN` / +`HUGGING_FACE_HUB_TOKEN` from the environment as an opt-in convenience fallback. [source,kotlin] ---- +import sk.ainet.data.source.DataSourceAuthToken import sk.ainet.data.source.DataSourceRequest import sk.ainet.data.source.JvmDataSourceResolver -val resolver = JvmDataSourceResolver() +val resolver = JvmDataSourceResolver( + huggingFaceToken = DataSourceAuthToken.from("hf_...") +) val artifact = resolver.resolve( DataSourceRequest( uri = "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" @@ -71,6 +75,28 @@ try { val bytes = artifact.readBytes() ---- +For per-request credentials, pass the token directly on `DataSourceRequest`. +This is useful when one resolver works with more than one private repository: + +[source,kotlin] +---- +val privateArtifact = resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/your-org/private-dataset@main/data/train.bin", + huggingFaceToken = DataSourceAuthToken.from("hf_...") + ) +) +---- + +To opt into JVM environment fallback: + +[source,kotlin] +---- +val resolver = JvmDataSourceResolver( + useEnvironmentHuggingFaceToken = true +) +---- + === Use sources with built-in loaders MNIST and Fashion-MNIST expose per-file URI overrides. CIFAR-10 exposes an @@ -82,10 +108,12 @@ locations, so existing code keeps working. import sk.ainet.data.mnist.MNIST import sk.ainet.data.mnist.MNISTLoaderConfig +val token = "hf_..." 
val train = MNIST.loadTrain( MNISTLoaderConfig( trainImagesUri = "file:///datasets/mnist/train-images-idx3-ubyte", - trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz" + trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz", + huggingFaceTokenProvider = { token } ) ) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt index 9344e150..1c63b86f 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -145,7 +146,9 @@ public data class CIFAR10Dataset( public data class CIFAR10LoaderConfig( val cacheDir: String = "cifar10-data", val useCache: Boolean = true, - val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL + val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL, + val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? 
= null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt new file mode 100644 index 00000000..1efeacd1 --- /dev/null +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt @@ -0,0 +1,9 @@ +package sk.ainet.data.common + +/** + * Supplies a Hugging Face token for built-in dataset loaders when their source + * URIs point at private Hugging Face artifacts. + */ +public fun interface DatasetHuggingFaceTokenProvider { + public fun token(): String? +} diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt index bdc0e99e..c9286506 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -150,7 +151,9 @@ public data class FashionMNISTLoaderConfig( val trainImagesUri: String = FashionMNISTConstants.TRAIN_IMAGES_URL, val trainLabelsUri: String = FashionMNISTConstants.TRAIN_LABELS_URL, val testImagesUri: String = FashionMNISTConstants.TEST_IMAGES_URL, - val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL + val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL, + val huggingFaceTokenProvider: 
DatasetHuggingFaceTokenProvider? = null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt index c4bfa76a..ae6881cf 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -128,7 +129,9 @@ public data class MNISTLoaderConfig( val trainImagesUri: String = MNISTConstants.TRAIN_IMAGES_URL, val trainLabelsUri: String = MNISTConstants.TRAIN_LABELS_URL, val testImagesUri: String = MNISTConstants.TEST_IMAGES_URL, - val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL + val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL, + val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt index 10fc06a3..5ecf29ee 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt @@ -15,7 +15,12 @@ import java.io.FileOutputStream * @property config The configuration for the CIFAR-10 loader. 
*/ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) { - private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** * Downloads the CIFAR-10 archive and extracts the specified batch file. diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt index 29ce6bbf..785910fd 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt @@ -1,6 +1,7 @@ package sk.ainet.data.common import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceAuthToken import sk.ainet.data.source.DataSourceRequest import sk.ainet.data.source.JvmDataSourceResolver import java.io.ByteArrayInputStream @@ -9,16 +10,22 @@ import java.util.zip.GZIPInputStream internal class JvmDatasetSourceReader( cacheDir: String, - useCache: Boolean + useCache: Boolean, + private val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? 
= null, + useEnvironmentHuggingFaceToken: Boolean = false ) { - private val resolver = JvmDataSourceResolver(File(cacheDir, "sources")) + private val resolver = JvmDataSourceResolver( + cacheDir = File(cacheDir, "sources"), + useEnvironmentHuggingFaceToken = useEnvironmentHuggingFaceToken + ) private val cachePolicy = if (useCache) CachePolicy.Use else CachePolicy.Refresh suspend fun read(uri: String): ByteArray { val artifact = resolver.resolve( DataSourceRequest( uri = uri, - cachePolicy = cachePolicy + cachePolicy = cachePolicy, + huggingFaceToken = DataSourceAuthToken.fromOrNull(huggingFaceTokenProvider?.token()) ) ) return artifact.readBytes() diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt index 9b100db9..27c598e6 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt @@ -8,7 +8,12 @@ import sk.ainet.data.common.JvmDatasetSourceReader * @property config The configuration for the Fashion-MNIST loader. */ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) { - private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** * Resolves, caches, and decompresses a file when needed. 
diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt index e5466b4f..0457d66c 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt @@ -8,7 +8,12 @@ import sk.ainet.data.common.JvmDatasetSourceReader * @property config The configuration for the MNIST loader. */ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) { - private val sources = JvmDatasetSourceReader(config.cacheDir, config.useCache) + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** * Resolves, caches, and decompresses a file when needed. diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt index 0e34245d..12077c0b 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt @@ -50,6 +50,33 @@ public data class HuggingFaceLocation( public val path: String? ) +/** + * Authentication token for provider-specific data source requests. + * + * The raw value is intentionally hidden from [toString] output so tokens are + * not leaked when requests or configs are logged. 
+ */ +public class DataSourceAuthToken private constructor( + private val value: String +) { + override fun toString(): String = "DataSourceAuthToken(***)" + + internal fun authorizationHeaderValue(): String = "Bearer $value" + + public companion object { + public fun from(value: String): DataSourceAuthToken { + val normalized = value.trim() + require(normalized.isNotEmpty()) { "Data source auth token cannot be blank" } + return DataSourceAuthToken(normalized) + } + + public fun fromOrNull(value: String?): DataSourceAuthToken? { + val normalized = value?.trim()?.takeIf { it.isNotEmpty() } ?: return null + return DataSourceAuthToken(normalized) + } + } +} + /** * A normalized, provider-aware source URI. */ @@ -70,7 +97,8 @@ public data class DataSourceRequest( public val uri: String, public val cachePolicy: CachePolicy = CachePolicy.Use, public val expectedSha256: String? = null, - public val headers: Map = emptyMap() + public val headers: Map = emptyMap(), + public val huggingFaceToken: DataSourceAuthToken? = null ) /** diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt index b7ee217f..484de79d 100644 --- a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt @@ -42,6 +42,32 @@ public fun interface DataSourceHeaderProvider { public fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map } +/** + * Supplies a Hugging Face token when a request does not carry one directly. + */ +public fun interface HuggingFaceTokenProvider { + public fun token(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): DataSourceAuthToken? 
+} + +/** + * Adds Hugging Face bearer auth from explicit request or resolver-level token + * configuration while leaving generic HTTP requests unchanged. + */ +public class HuggingFaceTokenHeaderProvider( + private val tokenProvider: HuggingFaceTokenProvider = HuggingFaceTokenProvider { _, _ -> null } +) : DataSourceHeaderProvider { + override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { + if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers + if (request.headers.hasAuthorizationHeader()) return request.headers + val token = request.huggingFaceToken ?: tokenProvider.token(request, parsedUri) ?: return request.headers + return request.headers + (AUTHORIZATION_HEADER to token.authorizationHeaderValue()) + } + + private companion object { + private const val AUTHORIZATION_HEADER = "Authorization" + } +} + /** * Computes checksums for integrity verification without tying resolver policy * to a concrete platform crypto API. @@ -204,9 +230,7 @@ public class DefaultDataSourceResolver( private val store: DataSourceArtifactStore, private val fetcher: RemoteDataSourceFetcher, private val checksum: DataSourceChecksum, - private val headerProvider: DataSourceHeaderProvider = DataSourceHeaderProvider { request, _ -> - request.headers - } + private val headerProvider: DataSourceHeaderProvider = HuggingFaceTokenHeaderProvider() ) : DataSourceResolver { override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact { val parsed = DataSourceUriParser.parse(request.uri) @@ -326,6 +350,10 @@ public class DefaultDataSourceResolver( } } +private fun Map.hasAuthorizationHeader(): Boolean { + return keys.any { it.equals("Authorization", ignoreCase = true) } +} + private class HashingRawSource( private val source: Source, private val hash: DataSourceHash diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt 
b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt index 01e4924e..e3803588 100644 --- a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt @@ -118,6 +118,73 @@ class DefaultDataSourceResolverTest { fetcher.lastHeaders ) } + + @Test + fun addsHuggingFaceTokenFromRequest() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum + ) + val token = DataSourceAuthToken.from("hf_request") + + resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/org/repo@main/file.bin", + headers = mapOf("Accept" to "application/octet-stream"), + huggingFaceToken = token + ) + ) + + assertEquals( + mapOf( + "Accept" to "application/octet-stream", + "Authorization" to "Bearer hf_request" + ), + fetcher.lastHeaders + ) + assertEquals("DataSourceAuthToken(***)", token.toString()) + } + + @Test + fun keepsExistingAuthorizationHeaderOverHuggingFaceToken() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://org/repo@main/file.bin", + headers = mapOf("authorization" to "Bearer explicit"), + huggingFaceToken = DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(mapOf("authorization" to "Bearer explicit"), fetcher.lastHeaders) + } + + @Test + fun doesNotAddHuggingFaceTokenToGenericHttp() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum 
+ ) + + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + huggingFaceToken = DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(emptyMap(), fetcher.lastHeaders) + } } private class MemoryDataSourceByteStore( diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt index 2a148bf4..fd7c9a88 100644 --- a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -48,13 +48,20 @@ public class KtorRemoteDataSourceFetcher( */ public class JvmDataSourceResolver( cacheDir: File = defaultCacheDir(), - fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher() + fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher(), + huggingFaceToken: DataSourceAuthToken? 
= null, + useEnvironmentHuggingFaceToken: Boolean = false ) : DataSourceResolver { private val delegate = DefaultDataSourceResolver( store = FileSystemDataSourceArtifactStore(Path(cacheDir.absolutePath)), fetcher = fetcher, checksum = JvmSha256DataSourceChecksum, - headerProvider = JvmHuggingFaceHeaderProvider + headerProvider = HuggingFaceTokenHeaderProvider( + JvmHuggingFaceTokenProvider( + configuredToken = huggingFaceToken, + useEnvironmentToken = useEnvironmentHuggingFaceToken + ) + ) ) override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { @@ -70,15 +77,16 @@ public class JvmDataSourceResolver( } } -internal object JvmHuggingFaceHeaderProvider : DataSourceHeaderProvider { - override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map<String, String> { - if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers - if (request.headers.keys.any { it.equals("Authorization", ignoreCase = true) }) return request.headers - val token = System.getenv("HF_TOKEN") - ?.takeIf { it.isNotBlank() } - ?: System.getenv("HUGGING_FACE_HUB_TOKEN")?.takeIf { it.isNotBlank() } - ?: return request.headers - return request.headers + ("Authorization" to "Bearer $token") +internal class JvmHuggingFaceTokenProvider( + private val configuredToken: DataSourceAuthToken?, + private val useEnvironmentToken: Boolean +) : HuggingFaceTokenProvider { + override fun token(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): DataSourceAuthToken?
{ + if (parsedUri.provider != DataSourceProvider.HuggingFace) return null + configuredToken?.let { return it } + if (!useEnvironmentToken) return null + return DataSourceAuthToken.fromOrNull(System.getenv("HF_TOKEN")) + ?: DataSourceAuthToken.fromOrNull(System.getenv("HUGGING_FACE_HUB_TOKEN")) } } diff --git a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt index f0102d38..4f997cd4 100644 --- a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt +++ b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt @@ -134,6 +134,51 @@ class JvmDataSourceResolverTest { root.deleteRecursively() } } + + @Test + fun sendsConfiguredHuggingFaceToken() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("payload".encodeToByteArray()) + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = fetcher, + huggingFaceToken = DataSourceAuthToken.from("hf_configured"), + useEnvironmentHuggingFaceToken = false + ) + + resolver.resolve(DataSourceRequest("hf://org/repo@main/file.bin")) + + assertEquals(mapOf("Authorization" to "Bearer hf_configured"), fetcher.lastHeaders) + } finally { + root.deleteRecursively() + } + } + + @Test + fun requestHuggingFaceTokenOverridesConfiguredToken() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("payload".encodeToByteArray()) + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = fetcher, + huggingFaceToken = DataSourceAuthToken.from("hf_configured"), + useEnvironmentHuggingFaceToken = false + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://org/repo@main/file.bin", + huggingFaceToken = 
DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(mapOf("Authorization" to "Bearer hf_request"), fetcher.lastHeaders) + } finally { + root.deleteRecursively() + } + } } private class FakeFetcher( @@ -142,8 +187,12 @@ private class FakeFetcher( var calls: Int = 0 private set + var lastHeaders: Map<String, String> = emptyMap() + private set + override suspend fun fetch(uri: String, headers: Map<String, String>): DataSourceRemoteContent { calls++ + lastHeaders = headers return DataSourceRemoteContent.fromBytes(bytes) } }