diff --git a/README.md b/README.md index 59beabfe..56eda582 100644 --- a/README.md +++ b/README.md @@ -212,6 +212,7 @@ Runnable examples: ### Data and I/O - Built-in loaders: MNIST, Fashion-MNIST, CIFAR-10 +- URI-backed data sources: `file://`, `https://`, `hf+https://`, and `hf://...` - Formats: GGUF, ONNX, SafeTensors, JSON, Image (JPEG, PNG) - Type-safe transform DSL: resize, crop, normalize, toTensor diff --git a/build.gradle.kts b/build.gradle.kts index 771b2133..d76ff097 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -151,6 +151,7 @@ dependencies { // skainet-data dokka(project(":skainet-data:skainet-data-api")) + dokka(project(":skainet-data:skainet-data-source")) dokka(project(":skainet-data:skainet-data-transform")) dokka(project(":skainet-data:skainet-data-simple")) dokka(project(":skainet-data:skainet-data-media")) @@ -178,4 +179,4 @@ tasks.register("bundleDokkaIntoSite") { dependsOn("dokkaGenerate") from(layout.buildDirectory.dir("dokka/html")) into(layout.projectDirectory.dir("docs/build/site/api")) -} \ No newline at end of file +} diff --git a/docs/modules/ROOT/nav.adoc b/docs/modules/ROOT/nav.adoc index 1c8ed540..e6a714cd 100644 --- a/docs/modules/ROOT/nav.adoc +++ b/docs/modules/ROOT/nav.adoc @@ -5,6 +5,7 @@ * Tutorials ** xref:tutorials/kotlin-getting-started.adoc[Kotlin getting started] ** xref:tutorials/java-getting-started.adoc[Java getting started] +** xref:tutorials/data-sources-getting-started.adoc[Data sources and Hugging Face] ** xref:tutorials/image-data-getting-started.adoc[Image and data API] ** xref:tutorials/hlo-getting-started.adoc[StableHLO getting started] ** xref:tutorials/minerva-getting-started.adoc[Minerva getting started] diff --git a/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc new file mode 100644 index 00000000..4557e875 --- /dev/null +++ b/docs/modules/ROOT/pages/tutorials/data-sources-getting-started.adoc @@ -0,0 +1,153 
@@ +== Data sources and Hugging Face + +SKaiNET separates artifact resolution from dataset parsing and preprocessing. +Use `skainet-data-source` when a dataset, tokenizer, model sidecar, or fixture +can live either on disk or behind a remote URI. + +[cols="1,3",options="header"] +|=== +| URI form | Meaning +| `file:///path/to/file` +| Read a local file. + +| `https://host/path/file` +| Download and cache a generic remote artifact. + +| `hf+https://huggingface.co/org/repo/resolve/main/file` +| Treat a Hugging Face resolve URL as a Hugging Face artifact. + +| `hf://org/repo@main/path/file` +| Expand to a Hugging Face model repository resolve URL. + +| `hf://datasets/org/repo@main/path/file` +| Expand to a Hugging Face dataset repository resolve URL. +|=== + +=== Add the modules + +For JVM consumers, add the source module beside the data loaders you use: + +[source,kotlin] +---- +dependencies { + implementation(platform("sk.ainet:skainet-bom:0.32.4")) + + implementation("sk.ainet.core:skainet-data-source-jvm") + implementation("sk.ainet.core:skainet-data-simple-jvm") +} +---- + +=== Resolve one artifact + +`JvmDataSourceResolver` materializes remote artifacts into a cache and returns +a `DataSourceArtifact` that opens a `kotlinx.io.Source`. Public Hugging Face +files do not need credentials. For private files, pass an explicit +`DataSourceAuthToken` on the request or the resolver. Existing `Authorization` +headers still take precedence. On JVM, the resolver can also read `HF_TOKEN` / +`HUGGING_FACE_HUB_TOKEN` from the environment as an opt-in convenience fallback. 
+ +[source,kotlin] +---- +import sk.ainet.data.source.DataSourceAuthToken +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver + +val resolver = JvmDataSourceResolver( + huggingFaceToken = DataSourceAuthToken.from("hf_...") +) +val artifact = resolver.resolve( + DataSourceRequest( + uri = "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" + ) +) + +println(artifact.filename) +println(artifact.localPath) + +val source = artifact.openSource() +try { + // Pass the source to a parser/loader for model-sized artifacts. +} finally { + source.close() +} + +// Convenience for small sidecars and tests. +val bytes = artifact.readBytes() +---- + +For per-request credentials, pass the token directly on `DataSourceRequest`. +This is useful when one resolver works with more than one private repository: + +[source,kotlin] +---- +val privateArtifact = resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/your-org/private-dataset@main/data/train.bin", + huggingFaceToken = DataSourceAuthToken.from("hf_...") + ) +) +---- + +To opt into JVM environment fallback: + +[source,kotlin] +---- +val resolver = JvmDataSourceResolver( + useEnvironmentHuggingFaceToken = true +) +---- + +=== Use sources with built-in loaders + +MNIST and Fashion-MNIST expose per-file URI overrides. CIFAR-10 exposes an +archive URI override. Defaults still point to the historical public dataset +locations, so existing code keeps working. + +[source,kotlin] +---- +import sk.ainet.data.mnist.MNIST +import sk.ainet.data.mnist.MNISTLoaderConfig + +val token = "hf_..." 
+val train = MNIST.loadTrain( + MNISTLoaderConfig( + trainImagesUri = "file:///datasets/mnist/train-images-idx3-ubyte", + trainLabelsUri = "hf+https://huggingface.co/your-org/mnist-idx/resolve/main/train-labels-idx1-ubyte.gz", + huggingFaceTokenProvider = { token } + ) +) + +val batches = train.batchIterator(batchSize = 64) +---- + +=== Cache behavior + +Use `CachePolicy.Use` for normal operation, `Refresh` to re-download, +`Offline` to require a cached copy, and `Bypass` to avoid writing the cache. +Built-in JVM loaders map `useCache = true` to `Use` and `useCache = false` +to `Refresh`. + +[source,kotlin] +---- +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceRequest + +val refreshed = resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/your-org/your-dataset@main/data/train-00000.parquet", + cachePolicy = CachePolicy.Refresh + ) +) +---- + +=== Keep preprocessing separate + +After bytes are parsed into a dataset, continue using the existing transform +DSL for image/tensor preprocessing: + +[source,kotlin] +---- +import sk.ainet.data.transform.mnistPreprocessing + +val preprocessing = mnistPreprocessing(ctx) +---- diff --git a/settings.gradle.kts b/settings.gradle.kts index bbfbc825..5e25ef39 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -48,6 +48,7 @@ include("skainet-backends:benchmarks:jvm-cpu-publish") // ====== DATA include("skainet-data:skainet-data-api") +include("skainet-data:skainet-data-source") include("skainet-data:skainet-data-transform") include("skainet-data:skainet-data-simple") include("skainet-data:skainet-data-media") diff --git a/skainet-data/skainet-data-simple/build.gradle.kts b/skainet-data/skainet-data-simple/build.gradle.kts index 562bedd6..9bd671f0 100644 --- a/skainet-data/skainet-data-simple/build.gradle.kts +++ b/skainet-data/skainet-data-simple/build.gradle.kts @@ -62,6 +62,7 @@ kotlin { } jvmMain.dependencies { + implementation(project(":skainet-data:skainet-data-source")) 
implementation(libs.ktor.client.cio) implementation(libs.ktor.client.plugins) implementation(libs.ktor.client.logging) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt index f2a1e106..1c63b86f 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/cifar10/CIFAR10Data.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -144,7 +145,10 @@ public data class CIFAR10Dataset( */ public data class CIFAR10LoaderConfig( val cacheDir: String = "cifar10-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val archiveUri: String = CIFAR10Constants.DOWNLOAD_URL, + val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt new file mode 100644 index 00000000..1efeacd1 --- /dev/null +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/common/DatasetHuggingFaceTokenProvider.kt @@ -0,0 +1,9 @@ +package sk.ainet.data.common + +/** + * Supplies a Hugging Face token for built-in dataset loaders when their source + * URIs point at private Hugging Face artifacts. + */ +public fun interface DatasetHuggingFaceTokenProvider { + public fun token(): String? 
+} diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt index b9227418..c9286506 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTData.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -146,7 +147,13 @@ public data class FashionMNISTDataset( */ public data class FashionMNISTLoaderConfig( val cacheDir: String = "fashion-mnist-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val trainImagesUri: String = FashionMNISTConstants.TRAIN_IMAGES_URL, + val trainLabelsUri: String = FashionMNISTConstants.TRAIN_LABELS_URL, + val testImagesUri: String = FashionMNISTConstants.TEST_IMAGES_URL, + val testLabelsUri: String = FashionMNISTConstants.TEST_LABELS_URL, + val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? 
= null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt index 45888031..9fff81fd 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderCommon.kt @@ -16,11 +16,11 @@ public abstract class FashionMNISTLoaderCommon(public val config: FashionMNISTLo */ override suspend fun loadTrainingData(): FashionMNISTDataset { val imagesBytes = downloadAndCacheFile( - FashionMNISTConstants.TRAIN_IMAGES_URL, + config.trainImagesUri, FashionMNISTConstants.TRAIN_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - FashionMNISTConstants.TRAIN_LABELS_URL, + config.trainLabelsUri, FashionMNISTConstants.TRAIN_LABELS_FILENAME ) @@ -34,11 +34,11 @@ public abstract class FashionMNISTLoaderCommon(public val config: FashionMNISTLo */ override suspend fun loadTestData(): FashionMNISTDataset { val imagesBytes = downloadAndCacheFile( - FashionMNISTConstants.TEST_IMAGES_URL, + config.testImagesUri, FashionMNISTConstants.TEST_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - FashionMNISTConstants.TEST_LABELS_URL, + config.testLabelsUri, FashionMNISTConstants.TEST_LABELS_FILENAME ) diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt index d120cdd6..ae6881cf 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTData.kt @@ -6,6 +6,7 @@ import sk.ainet.context.DefaultDataExecutionContext import 
sk.ainet.context.ExecutionContext import sk.ainet.data.DataBatch import sk.ainet.data.Dataset +import sk.ainet.data.common.DatasetHuggingFaceTokenProvider import sk.ainet.lang.tensor.Shape import sk.ainet.lang.tensor.Tensor import sk.ainet.lang.types.DType @@ -124,7 +125,13 @@ public data class MNISTDataset( */ public data class MNISTLoaderConfig( val cacheDir: String = "mnist-data", - val useCache: Boolean = true + val useCache: Boolean = true, + val trainImagesUri: String = MNISTConstants.TRAIN_IMAGES_URL, + val trainLabelsUri: String = MNISTConstants.TRAIN_LABELS_URL, + val testImagesUri: String = MNISTConstants.TEST_IMAGES_URL, + val testLabelsUri: String = MNISTConstants.TEST_LABELS_URL, + val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? = null, + val useEnvironmentHuggingFaceToken: Boolean = false ) /** @@ -164,4 +171,4 @@ public interface MNISTLoader { * @return The MNIST test dataset. */ public suspend fun loadTestData(): MNISTDataset -} \ No newline at end of file +} diff --git a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt index 4c334d31..66ef8085 100644 --- a/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt +++ b/skainet-data/skainet-data-simple/src/commonMain/kotlin/sk/ainet/data/mnist/MNISTLoaderCommon.kt @@ -14,11 +14,11 @@ public abstract class MNISTLoaderCommon(public val config: MNISTLoaderConfig) : */ override suspend fun loadTrainingData(): MNISTDataset { val imagesBytes = downloadAndCacheFile( - MNISTConstants.TRAIN_IMAGES_URL, + config.trainImagesUri, MNISTConstants.TRAIN_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - MNISTConstants.TRAIN_LABELS_URL, + config.trainLabelsUri, MNISTConstants.TRAIN_LABELS_FILENAME ) @@ -32,11 +32,11 @@ public abstract class MNISTLoaderCommon(public val config: MNISTLoaderConfig) : */ override 
suspend fun loadTestData(): MNISTDataset { val imagesBytes = downloadAndCacheFile( - MNISTConstants.TEST_IMAGES_URL, + config.testImagesUri, MNISTConstants.TEST_IMAGES_FILENAME ) val labelsBytes = downloadAndCacheFile( - MNISTConstants.TEST_LABELS_URL, + config.testLabelsUri, MNISTConstants.TEST_LABELS_FILENAME ) diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt index 82212e3f..5ecf29ee 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/cifar10/CIFAR10LoaderJvm.kt @@ -1,19 +1,11 @@ package sk.ainet.data.cifar10 -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext +import sk.ainet.data.common.JvmDatasetSourceReader +import sk.ainet.data.common.gunzip import java.io.File -import java.io.FileInputStream import java.io.FileOutputStream -import java.io.ByteArrayInputStream -import java.util.zip.GZIPInputStream /** * JVM implementation of the CIFAR-10 loader. @@ -23,6 +15,12 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the CIFAR-10 loader. */ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon(config) { + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** * Downloads the CIFAR-10 archive and extracts the specified batch file. 
@@ -45,21 +43,10 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon return@withContext batchFile.readBytes() } - // Check if we need to download and extract the archive + // Check if we need to resolve and extract the archive if (!extractedDir.exists() || !config.useCache) { - val archiveFile = File(cacheDir, CIFAR10Constants.ARCHIVE_FILENAME) - - // Download if not cached - if (!archiveFile.exists() || !config.useCache) { - println("Downloading CIFAR-10 archive: ${CIFAR10Constants.DOWNLOAD_URL}") - downloadFile(CIFAR10Constants.DOWNLOAD_URL, archiveFile.path) - } else { - println("Using cached archive: ${archiveFile.path}") - } - - // Extract the archive println("Extracting CIFAR-10 archive...") - extractTarGz(archiveFile.path, cacheDir.path) + extractTarGz(sources.read(config.archiveUri), cacheDir.path) } if (!batchFile.exists()) { @@ -69,53 +56,15 @@ public class CIFAR10LoaderJvm(config: CIFAR10LoaderConfig) : CIFAR10LoaderCommon return@withContext batchFile.readBytes() } - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. - */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files (CIFAR-10 is ~170MB) - install(HttpTimeout) { - requestTimeoutMillis = 600000 // 10 minutes - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 600000 // 10 minutes - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path} (${responseBody.size} bytes)") - } finally { - client.close() - } - } - /** * Extracts a .tar.gz archive using a simple TAR parser. * - * @param archivePath The path to the .tar.gz file. + * @param archiveBytes The bytes of the .tar.gz file. 
* @param outputDir The directory to extract files to. */ - private fun extractTarGz(archivePath: String, outputDir: String) { + private fun extractTarGz(archiveBytes: ByteArray, outputDir: String) { val outputDirFile = File(outputDir) - - // First, decompress gzip to get the tar content - val tarBytes = GZIPInputStream(FileInputStream(archivePath)).use { gzipIn -> - gzipIn.readBytes() - } - - // Parse the TAR archive - extractTar(tarBytes, outputDirFile) + extractTar(archiveBytes.gunzip(), outputDirFile) } /** diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt new file mode 100644 index 00000000..785910fd --- /dev/null +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/common/JvmDatasetSourceReader.kt @@ -0,0 +1,47 @@ +package sk.ainet.data.common + +import sk.ainet.data.source.CachePolicy +import sk.ainet.data.source.DataSourceAuthToken +import sk.ainet.data.source.DataSourceRequest +import sk.ainet.data.source.JvmDataSourceResolver +import java.io.ByteArrayInputStream +import java.io.File +import java.util.zip.GZIPInputStream + +internal class JvmDatasetSourceReader( + cacheDir: String, + useCache: Boolean, + private val huggingFaceTokenProvider: DatasetHuggingFaceTokenProvider? 
= null, + useEnvironmentHuggingFaceToken: Boolean = false +) { + private val resolver = JvmDataSourceResolver( + cacheDir = File(cacheDir, "sources"), + useEnvironmentHuggingFaceToken = useEnvironmentHuggingFaceToken + ) + private val cachePolicy = if (useCache) CachePolicy.Use else CachePolicy.Refresh + + suspend fun read(uri: String): ByteArray { + val artifact = resolver.resolve( + DataSourceRequest( + uri = uri, + cachePolicy = cachePolicy, + huggingFaceToken = DataSourceAuthToken.fromOrNull(huggingFaceTokenProvider?.token()) + ) + ) + return artifact.readBytes() + } + + suspend fun readGzipDecoded(uri: String): ByteArray = read(uri).gunzipIfNeeded() +} + +internal fun ByteArray.gunzip(): ByteArray { + return GZIPInputStream(ByteArrayInputStream(this)).use { it.readBytes() } +} + +internal fun ByteArray.gunzipIfNeeded(): ByteArray { + return if (isGzip()) gunzip() else this +} + +private fun ByteArray.isGzip(): Boolean { + return size >= 2 && this[0] == 0x1f.toByte() && this[1] == 0x8b.toByte() +} diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt index 2444a779..27c598e6 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/fashionmnist/FashionMNISTLoaderJvm.kt @@ -1,18 +1,6 @@ package sk.ainet.data.fashionmnist -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import java.io.File -import java.io.FileInputStream -import java.io.FileOutputStream -import 
java.util.zip.GZIPInputStream +import sk.ainet.data.common.JvmDatasetSourceReader /** * JVM implementation of the Fashion-MNIST loader. @@ -20,91 +8,22 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the Fashion-MNIST loader. */ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMNISTLoaderCommon(config) { + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** - * Downloads and caches a file. + * Resolves, caches, and decompresses a file when needed. * * @param url The URL to download from. * @param filename The name of the file to save. * @return The bytes of the decompressed file. */ - override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val cacheDir = File(config.cacheDir) - if (!cacheDir.exists()) { - cacheDir.mkdirs() - } - - val gzipFile = File(cacheDir, filename) - val decompressedFile = File(cacheDir, filename.removeSuffix(".gz")) - - // Check if the decompressed file already exists in cache - if (config.useCache && decompressedFile.exists()) { - println("Using cached file: ${decompressedFile.path}") - return@withContext decompressedFile.readBytes() - } - - // Check if the gzip file already exists in cache - if (!gzipFile.exists() || !config.useCache) { - println("Downloading Fashion-MNIST file: $url") - downloadFile(url, gzipFile.path) - } else { - println("Using cached gzip file: ${gzipFile.path}") - } - - // Decompress the gzip file - println("Decompressing file: ${gzipFile.path}") - decompressGzipFile(gzipFile.path, decompressedFile.path) - - return@withContext decompressedFile.readBytes() - } - - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. 
- */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files - install(HttpTimeout) { - requestTimeoutMillis = 60000 // 60 seconds - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 60000 // 60 seconds - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path}") - } finally { - client.close() - } - } - - /** - * Decompresses a gzip file. - * - * @param gzipFilePath The path to the gzip file. - * @param outputFilePath The path to save the decompressed file to. - */ - private fun decompressGzipFile(gzipFilePath: String, outputFilePath: String) { - GZIPInputStream(FileInputStream(gzipFilePath)).use { gzipInputStream -> - FileOutputStream(outputFilePath).use { outputStream -> - val buffer = ByteArray(1024) - var len: Int - while (gzipInputStream.read(buffer).also { len = it } > 0) { - outputStream.write(buffer, 0, len) - } - } - } + override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray { + return sources.readGzipDecoded(url) } public companion object { @@ -137,4 +56,5 @@ public class FashionMNISTLoaderJvm(config: FashionMNISTLoaderConfig) : FashionMN return FashionMNISTLoaderJvm(config) } } + } diff --git a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt index b6bcb9aa..0457d66c 100644 --- a/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt +++ b/skainet-data/skainet-data-simple/src/jvmMain/kotlin/sk/ainet/data/mnist/MNISTLoaderJvm.kt @@ -1,18 +1,6 @@ package sk.ainet.data.mnist -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import 
io.ktor.client.plugins.logging.Logging -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.request.get -import io.ktor.client.statement.HttpResponse -import io.ktor.client.call.body -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import java.io.File -import java.io.FileInputStream -import java.io.FileOutputStream -import java.util.zip.GZIPInputStream +import sk.ainet.data.common.JvmDatasetSourceReader /** * JVM implementation of the MNIST loader. @@ -20,91 +8,22 @@ import java.util.zip.GZIPInputStream * @property config The configuration for the MNIST loader. */ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(config) { + private val sources = JvmDatasetSourceReader( + cacheDir = config.cacheDir, + useCache = config.useCache, + huggingFaceTokenProvider = config.huggingFaceTokenProvider, + useEnvironmentHuggingFaceToken = config.useEnvironmentHuggingFaceToken + ) /** - * Downloads and caches a file. + * Resolves, caches, and decompresses a file when needed. * * @param url The URL to download from. * @param filename The name of the file to save. * @return The bytes of the decompressed file. 
*/ - override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray = withContext(Dispatchers.IO) { - val cacheDir = File(config.cacheDir) - if (!cacheDir.exists()) { - cacheDir.mkdirs() - } - - val gzipFile = File(cacheDir, filename) - val decompressedFile = File(cacheDir, filename.removeSuffix(".gz")) - - // Check if the decompressed file already exists in cache - if (config.useCache && decompressedFile.exists()) { - println("Using cached file: ${decompressedFile.path}") - return@withContext decompressedFile.readBytes() - } - - // Check if the gzip file already exists in cache - if (!gzipFile.exists() || !config.useCache) { - println("Downloading file: $url") - downloadFile(url, gzipFile.path) - } else { - println("Using cached gzip file: ${gzipFile.path}") - } - - // Decompress the gzip file - println("Decompressing file: ${gzipFile.path}") - decompressGzipFile(gzipFile.path, decompressedFile.path) - - return@withContext decompressedFile.readBytes() - } - - /** - * Downloads a file from a URL. - * - * @param url The URL to download from. - * @param outputPath The path to save the file to. - */ - private suspend fun downloadFile(url: String, outputPath: String) { - val client = HttpClient(CIO) { - install(Logging) - - // Configure timeout for large files - install(HttpTimeout) { - requestTimeoutMillis = 300000 // 5 minutes - connectTimeoutMillis = 60000 // 60 seconds - socketTimeoutMillis = 300000 // 5 minutes - } - } - - try { - val file = File(outputPath) - - val httpResponse: HttpResponse = client.get(url) - val responseBody: ByteArray = httpResponse.body() - file.writeBytes(responseBody) - - println("File saved to ${file.path}") - } finally { - client.close() - } - } - - /** - * Decompresses a gzip file. - * - * @param gzipFilePath The path to the gzip file. - * @param outputFilePath The path to save the decompressed file to. 
- */ - private fun decompressGzipFile(gzipFilePath: String, outputFilePath: String) { - GZIPInputStream(FileInputStream(gzipFilePath)).use { gzipInputStream -> - FileOutputStream(outputFilePath).use { outputStream -> - val buffer = ByteArray(1024) - var len: Int - while (gzipInputStream.read(buffer).also { len = it } > 0) { - outputStream.write(buffer, 0, len) - } - } - } + override suspend fun downloadAndCacheFile(url: String, filename: String): ByteArray { + return sources.readGzipDecoded(url) } public companion object { @@ -137,4 +56,5 @@ public class MNISTLoaderJvm(config: MNISTLoaderConfig) : MNISTLoaderCommon(confi return MNISTLoaderJvm(config) } } + } diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt index 06b23b39..541516b3 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/cifar10/CIFAR10LoaderTest.kt @@ -60,7 +60,8 @@ class CIFAR10LoaderTest { fun testLoaderConfiguration() { val config = CIFAR10LoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + archiveUri = "hf+https://huggingface.co/datasets/cifar10/resolve/main/cifar-10-binary.tar.gz" ) val loader = createCIFAR10Loader(config) diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt index 5f80dc31..30135890 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/fashionmnist/FashionMNISTLoaderTest.kt @@ -60,7 +60,9 @@ class FashionMNISTLoaderTest { fun 
testLoaderConfiguration() { val config = FashionMNISTLoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + trainImagesUri = "file:///datasets/fashion-mnist/train-images", + trainLabelsUri = "hf+https://huggingface.co/datasets/zalando-datasets/fashion_mnist/resolve/main/train-labels" ) val loader = createFashionMNISTLoader(config) diff --git a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt index 82e04d75..94796645 100644 --- a/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt +++ b/skainet-data/skainet-data-simple/src/jvmTest/kotlin/sk/ainet/io/data/mnist/MNISTLoaderTest.kt @@ -6,6 +6,7 @@ import sk.ainet.data.mnist.MNISTImage import sk.ainet.data.mnist.MNISTLoaderConfig import sk.ainet.data.mnist.MNISTLoaderFactory import sk.ainet.data.mnist.MNISTLoaderCommon +import java.nio.file.Files import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -60,12 +61,37 @@ class MNISTLoaderTest { fun testLoaderConfiguration() { val config = MNISTLoaderConfig( cacheDir = "custom-cache-dir", - useCache = false + useCache = false, + trainImagesUri = "file:///datasets/mnist/train-images", + trainLabelsUri = "hf+https://huggingface.co/datasets/mnist/mnist/resolve/main/train-labels" ) val loader = MNISTLoaderFactory.create(config) assertNotNull(loader) } + + @Test + fun testJvmLoaderReadsConfiguredFileUris() = runBlocking { + val root = Files.createTempDirectory("skainet-mnist-loader-test").toFile() + try { + val trainImages = root.resolve("train-images.idx") + val trainLabels = root.resolve("train-labels.idx") + trainImages.writeBytes(TRAINING_IMAGES_BYTES) + trainLabels.writeBytes(TRAINING_LABELS_BYTES) + val config = MNISTLoaderConfig( + cacheDir = root.resolve("cache").absolutePath, + useCache = false, + trainImagesUri = 
trainImages.toURI().toString(), + trainLabelsUri = trainLabels.toURI().toString() + ) + + val dataset = MNISTLoaderFactory.create(config).loadTrainingData() + + assertEquals(EXPECTED_TRAINING_DATA, dataset.images) + } finally { + root.deleteRecursively() + } + } } private fun createFakeLoader(config: MNISTLoaderConfig = MNISTLoaderConfig()): FakeMNISTLoader { diff --git a/skainet-data/skainet-data-source/build.gradle.kts b/skainet-data/skainet-data-source/build.gradle.kts new file mode 100644 index 00000000..7d7c3065 --- /dev/null +++ b/skainet-data/skainet-data-source/build.gradle.kts @@ -0,0 +1,37 @@ +import org.jetbrains.kotlin.gradle.dsl.JvmTarget + +plugins { + alias(libs.plugins.kotlinMultiplatform) + alias(libs.plugins.vanniktech.mavenPublish) + id("sk.ainet.dokka") +} + +kotlin { + explicitApi() + + jvm { + compilerOptions { + jvmTarget.set(JvmTarget.JVM_11) + } + } + + sourceSets { + commonMain.dependencies { + implementation(libs.kotlinx.coroutines) + implementation(libs.kotlinx.io.core) + } + + commonTest.dependencies { + implementation(libs.kotlin.test) + implementation(libs.kotlinx.coroutines.test) + } + + jvmMain.dependencies { + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.core) + implementation(libs.ktor.client.plugins) + implementation(libs.kotlinx.coroutines.core.jvm) + } + + } +} diff --git a/skainet-data/skainet-data-source/gradle.properties b/skainet-data/skainet-data-source/gradle.properties new file mode 100644 index 00000000..3516f9dd --- /dev/null +++ b/skainet-data/skainet-data-source/gradle.properties @@ -0,0 +1,2 @@ +POM_ARTIFACT_ID=skainet-data-source +POM_NAME=skainet data source diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt new file mode 100644 index 00000000..12077c0b --- /dev/null +++ 
b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceModels.kt @@ -0,0 +1,163 @@ +package sk.ainet.data.source + +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.readByteArray + +/** + * Cache behavior requested by a caller resolving a data artifact. + */ +public enum class CachePolicy { + /** Use a cached artifact when present, otherwise fetch and cache it. */ + Use, + + /** Fetch the artifact again and replace any existing cached copy. */ + Refresh, + + /** Require an existing cached or local artifact; do not use the network. */ + Offline, + + /** Fetch or read the artifact without writing a persistent cache entry. */ + Bypass +} + +/** + * High-level provider implied by a source URI. + */ +public enum class DataSourceProvider { + File, + Http, + HuggingFace +} + +/** + * Hugging Face repository namespace encoded by an `hf://` URI. + */ +public enum class HuggingFaceRepoType { + Model, + Dataset, + Space +} + +/** + * Parsed Hugging Face location, when a URI uses SKaiNET's `hf://` shorthand + * or the explicit `hf+https://...` provider prefix. + */ +public data class HuggingFaceLocation( + public val repoType: HuggingFaceRepoType, + public val repoId: String?, + public val revision: String?, + public val path: String? +) + +/** + * Authentication token for provider-specific data source requests. + * + * The raw value is intentionally hidden from [toString] output so tokens are + * not leaked when requests or configs are logged. 
+ */ +public class DataSourceAuthToken private constructor( + private val value: String +) { + override fun toString(): String = "DataSourceAuthToken(***)" + + internal fun authorizationHeaderValue(): String = "Bearer $value" + + public companion object { + public fun from(value: String): DataSourceAuthToken { + val normalized = value.trim() + require(normalized.isNotEmpty()) { "Data source auth token cannot be blank" } + return DataSourceAuthToken(normalized) + } + + public fun fromOrNull(value: String?): DataSourceAuthToken? { + val normalized = value?.trim()?.takeIf { it.isNotEmpty() } ?: return null + return DataSourceAuthToken(normalized) + } + } +} + +/** + * A normalized, provider-aware source URI. + */ +public data class ParsedDataSourceUri( + public val rawUri: String, + public val provider: DataSourceProvider, + public val transportUri: String, + public val filename: String, + public val cacheKey: String, + public val localPath: String? = null, + public val huggingFace: HuggingFaceLocation? = null +) + +/** + * Request to resolve a local or remote artifact. + */ +public data class DataSourceRequest( + public val uri: String, + public val cachePolicy: CachePolicy = CachePolicy.Use, + public val expectedSha256: String? = null, + public val headers: Map = emptyMap(), + public val huggingFaceToken: DataSourceAuthToken? = null +) + +/** + * A resolved artifact. Remote artifacts may expose a [localPath] when they + * have been materialized into a platform cache. + */ +public class DataSourceArtifact( + public val request: DataSourceRequest, + public val parsedUri: ParsedDataSourceUri, + public val filename: String, + public val localPath: String?, + public val sizeBytes: Long?, + public val cacheHit: Boolean, + private val sourceOpener: suspend () -> Source +) { + /** + * Opens a fresh source for this artifact. Callers own and must close it. + */ + public suspend fun openSource(): Source = sourceOpener() + + /** + * Convenience for small artifacts. 
Prefer [openSource] or [copyTo] for + * model-scale data. + */ + public suspend fun readBytes(): ByteArray { + val source = openSource() + return try { + source.readByteArray() + } finally { + source.close() + } + } + + /** + * Streams this artifact into [sink]. The source is closed after copying; + * [sink] is left open for the caller. + */ + public suspend fun copyTo(sink: Sink): Long { + val source = openSource() + return try { + source.transferTo(sink) + } finally { + source.close() + } + } +} + +/** + * Resolves source URIs into readable data artifacts. + */ +public interface DataSourceResolver { + public suspend fun resolve(request: DataSourceRequest): DataSourceArtifact +} + +public open class DataSourceException( + message: String, + cause: Throwable? = null +) : RuntimeException(message, cause) + +public class UnsupportedDataSourceUriException( + uri: String +) : DataSourceException("Unsupported data source URI: $uri") diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt new file mode 100644 index 00000000..9281358a --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DataSourceUriParser.kt @@ -0,0 +1,205 @@ +package sk.ainet.data.source + +/** + * Parses SKaiNET data source URIs. 
+ * + * Supported forms: + * - `file:///absolute/path` + * - `/absolute/or/relative/path` + * - `https://host/path` + * - `hf+https://huggingface.co/org/repo/resolve/main/file` + * - `hf://org/repo@revision/path/to/file` + * - `hf://datasets/org/repo@revision/path/to/file` + */ +public object DataSourceUriParser { + public fun parse(uri: String): ParsedDataSourceUri { + val raw = uri.trim() + require(raw.isNotEmpty()) { "Data source URI must not be blank" } + + return when { + raw.startsWith(HF_HTTPS_PREFIX) -> parseHfHttps(raw) + raw.startsWith(HF_URI_PREFIX) -> parseHfUri(raw) + raw.startsWith(FILE_URI_SCHEME) -> parseFileUri(raw) + raw.startsWith(HTTPS_PREFIX) || raw.startsWith(HTTP_PREFIX) -> parseHttp(raw) + raw.contains("://") -> throw UnsupportedDataSourceUriException(raw) + else -> parsePlainFilePath(raw) + } + } + + private fun parseFileUri(raw: String): ParsedDataSourceUri { + val localPath = normalizeFileUriPath(raw.removePrefix(FILE_URI_SCHEME)) + val filename = extractFilename(localPath) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.File, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.File, localPath, filename), + localPath = localPath + ) + } + + private fun parsePlainFilePath(raw: String): ParsedDataSourceUri { + val filename = extractFilename(raw) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.File, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.File, raw, filename), + localPath = raw + ) + } + + private fun parseHttp(raw: String): ParsedDataSourceUri { + val filename = extractFilename(raw) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.Http, + transportUri = raw, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.Http, raw, filename) + ) + } + + private fun parseHfHttps(raw: String): ParsedDataSourceUri { + val transportUri = raw.removePrefix("hf+") + val filename = 
extractFilename(transportUri) + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.HuggingFace, + transportUri = transportUri, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.HuggingFace, transportUri, filename), + huggingFace = HuggingFaceLocation( + repoType = HuggingFaceRepoType.Model, + repoId = null, + revision = null, + path = null + ) + ) + } + + private fun parseHfUri(raw: String): ParsedDataSourceUri { + val body = raw.removePrefix(HF_URI_PREFIX).trim('/') + val segments = body.split('/').filter { it.isNotBlank() } + require(segments.size >= 3) { + "hf:// URI must include repo owner, repo name, and file path: $raw" + } + + val (repoType, repoStart) = when (segments.first()) { + "models", "model" -> HuggingFaceRepoType.Model to 1 + "datasets", "dataset" -> HuggingFaceRepoType.Dataset to 1 + "spaces", "space" -> HuggingFaceRepoType.Space to 1 + else -> HuggingFaceRepoType.Model to 0 + } + require(segments.size - repoStart >= 3) { + "hf:// URI must include repo owner, repo name, and file path: $raw" + } + + val owner = segments[repoStart] + val repoAndRevision = segments[repoStart + 1] + val repoName = repoAndRevision.substringBefore('@') + val revision = repoAndRevision.substringAfter('@', "main") + val filePath = segments.drop(repoStart + 2).joinToString("/") + val repoId = "$owner/$repoName" + val prefix = when (repoType) { + HuggingFaceRepoType.Model -> "" + HuggingFaceRepoType.Dataset -> "datasets/" + HuggingFaceRepoType.Space -> "spaces/" + } + val transportUri = "https://huggingface.co/$prefix$repoId/resolve/$revision/$filePath" + val filename = extractFilename(filePath) + + return ParsedDataSourceUri( + rawUri = raw, + provider = DataSourceProvider.HuggingFace, + transportUri = transportUri, + filename = filename, + cacheKey = cacheKey(DataSourceProvider.HuggingFace, transportUri, filename), + huggingFace = HuggingFaceLocation( + repoType = repoType, + repoId = repoId, + revision = revision, + path = filePath + 
) + ) + } + + private fun normalizeFileUriPath(path: String): String { + val normalized = when { + path.startsWith("//localhost/") -> path.removePrefix("//localhost") + path.startsWith("///") -> path.drop(2) + path.startsWith("//") -> path.drop(1) + path.startsWith("/") -> path + else -> "/$path" + } + return percentDecode(normalized) + } + + private fun extractFilename(value: String): String { + val withoutFragment = value.substringBefore('#').substringBefore('?').trimEnd('/') + val filename = withoutFragment.substringAfterLast('/', missingDelimiterValue = withoutFragment) + return percentDecode(filename).ifBlank { "artifact" } + } + + private fun percentDecode(value: String): String { + val out = StringBuilder(value.length) + var i = 0 + while (i < value.length) { + val c = value[i] + if (c == '%' && i + 2 < value.length) { + val decoded = hexByte(value[i + 1], value[i + 2]) + if (decoded != null) { + out.append(decoded.toInt().toChar()) + i += 3 + continue + } + } + out.append(c) + i++ + } + return out.toString() + } + + private fun hexByte(high: Char, low: Char): Byte? { + val hi = high.digitToIntOrNull(16) ?: return null + val lo = low.digitToIntOrNull(16) ?: return null + return ((hi shl 4) or lo).toByte() + } + + private fun cacheKey(provider: DataSourceProvider, normalizedUri: String, filename: String): String { + val safeName = filename.map { ch -> + if (ch.isLetterOrDigit() || ch == '.' 
|| ch == '-' || ch == '_') ch else '_' + }.joinToString("") + return "${provider.name.lowercase()}-${fnv1a32Hex(normalizedUri)}-$safeName" + } + + private fun fnv1a32Hex(value: String): String { + var hash = FNV_OFFSET + val bytes = value.encodeToByteArray() + for (byte in bytes) { + hash = hash xor (byte.toInt() and 0xff) + hash *= FNV_PRIME + } + return hash.toHex8() + } + + private fun Int.toHex8(): String { + val chars = CharArray(8) + for (i in chars.indices) { + val shift = (7 - i) * 4 + chars[i] = HEX[(this ushr shift) and 0x0f] + } + return chars.concatToString() + } + + private const val FILE_URI_SCHEME = "file:" + private const val HTTP_PREFIX = "http://" + private const val HTTPS_PREFIX = "https://" + private const val HF_HTTPS_PREFIX = "hf+https://" + private const val HF_URI_PREFIX = "hf://" + private const val FNV_OFFSET = -2128831035 + private const val FNV_PRIME = 16777619 + private val HEX: CharArray = "0123456789abcdef".toCharArray() +} diff --git a/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt new file mode 100644 index 00000000..484de79d --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonMain/kotlin/sk/ainet/data/source/DefaultDataSourceResolver.kt @@ -0,0 +1,381 @@ +package sk.ainet.data.source + +import kotlinx.io.Buffer +import kotlinx.io.RawSource +import kotlinx.io.Sink +import kotlinx.io.Source +import kotlinx.io.buffered +import kotlinx.io.files.FileSystem +import kotlinx.io.files.Path +import kotlinx.io.files.SystemFileSystem +import kotlinx.io.readByteArray + +/** + * Remote response body exposed as a one-shot [Source]. + */ +public class DataSourceRemoteContent( + public val source: Source, + public val sizeBytes: Long? 
= null +) { + public companion object { + public fun fromBytes(bytes: ByteArray): DataSourceRemoteContent { + return DataSourceRemoteContent( + source = bytes.toDataSourceSource(), + sizeBytes = bytes.size.toLong() + ) + } + } +} + +/** + * Fetches a remote URI as a stream. Kept injectable so tests and applications + * can provide their own HTTP stack or policy layer. + */ +public fun interface RemoteDataSourceFetcher { + public suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent +} + +/** + * Adds platform or application-specific headers to a resolved remote request. + */ +public fun interface DataSourceHeaderProvider { + public fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map +} + +/** + * Supplies a Hugging Face token when a request does not carry one directly. + */ +public fun interface HuggingFaceTokenProvider { + public fun token(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): DataSourceAuthToken? +} + +/** + * Adds Hugging Face bearer auth from explicit request or resolver-level token + * configuration while leaving generic HTTP requests unchanged. + */ +public class HuggingFaceTokenHeaderProvider( + private val tokenProvider: HuggingFaceTokenProvider = HuggingFaceTokenProvider { _, _ -> null } +) : DataSourceHeaderProvider { + override fun headers(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): Map { + if (parsedUri.provider != DataSourceProvider.HuggingFace) return request.headers + if (request.headers.hasAuthorizationHeader()) return request.headers + val token = request.huggingFaceToken ?: tokenProvider.token(request, parsedUri) ?: return request.headers + return request.headers + (AUTHORIZATION_HEADER to token.authorizationHeaderValue()) + } + + private companion object { + private const val AUTHORIZATION_HEADER = "Authorization" + } +} + +/** + * Computes checksums for integrity verification without tying resolver policy + * to a concrete platform crypto API. 
+ */ +public interface DataSourceChecksum { + public fun newSha256(): DataSourceHash +} + +/** + * Incremental hash state used while streaming artifact bytes. + */ +public interface DataSourceHash { + public fun update(bytes: ByteArray, startIndex: Int = 0, endIndex: Int = bytes.size) + public fun hex(): String +} + +/** + * Platform storage adapter used by [DefaultDataSourceResolver]. + */ +public interface DataSourceArtifactStore { + public suspend fun readLocal(path: String): DataSourceStoredArtifact? + public suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? + public suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long? = null, + validate: suspend (DataSourceStoredArtifact) -> Unit = {} + ): DataSourceStoredArtifact +} + +/** + * A materialized artifact used by the common resolver core. + */ +public class DataSourceStoredArtifact( + public val localPath: String?, + public val sizeBytes: Long?, + private val sourceOpener: suspend () -> Source +) { + public suspend fun openSource(): Source = sourceOpener() + + public suspend fun readBytes(): ByteArray { + val source = openSource() + return try { + source.readByteArray() + } finally { + source.close() + } + } + + public suspend fun copyTo(sink: Sink): Long { + val source = openSource() + return try { + source.transferTo(sink) + } finally { + source.close() + } + } + + public companion object { + public fun inMemory( + bytes: ByteArray, + localPath: String? = null + ): DataSourceStoredArtifact { + val owned = bytes.copyOf() + return DataSourceStoredArtifact( + localPath = localPath, + sizeBytes = owned.size.toLong(), + sourceOpener = { owned.toDataSourceSource() } + ) + } + + public fun inMemoryFrom( + source: Source, + localPath: String? = null, + sizeBytes: Long? 
= null + ): DataSourceStoredArtifact { + val buffer = Buffer() + val copied = source.transferTo(buffer) + return DataSourceStoredArtifact( + localPath = localPath, + sizeBytes = sizeBytes ?: copied, + sourceOpener = { buffer.copy() } + ) + } + } +} + +/** + * Filesystem-backed artifact store built on kotlinx-io so the cache policy + * remains reusable across KMP targets that expose [SystemFileSystem]. + */ +public class FileSystemDataSourceArtifactStore( + private val cacheDir: Path, + private val fileSystem: FileSystem = SystemFileSystem +) : DataSourceArtifactStore { + override suspend fun readLocal(path: String): DataSourceStoredArtifact? { + val localPath = Path(path) + val metadata = fileSystem.metadataOrNull(localPath) ?: return null + if (!metadata.isRegularFile) { + throw DataSourceException("Data source path is not a file: $path") + } + val resolved = fileSystem.resolve(localPath) + return resolved.toStoredArtifact(metadata.size) + } + + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? 
{ + val target = Path(cacheDir, cacheKey) + val metadata = fileSystem.metadataOrNull(target) ?: return null + return if (metadata.isRegularFile) target.toStoredArtifact(metadata.size) else null + } + + override suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long?, + validate: suspend (DataSourceStoredArtifact) -> Unit + ): DataSourceStoredArtifact { + fileSystem.createDirectories(cacheDir) + val target = Path(cacheDir, cacheKey) + val temp = Path(cacheDir, "$cacheKey.tmp") + + val sink = fileSystem.sink(temp).buffered() + try { + source.transferTo(sink) + sink.flush() + } finally { + sink.close() + } + + val tempMetadata = fileSystem.metadataOrNull(temp) + val tempArtifact = temp.toStoredArtifact(tempMetadata?.size ?: sizeBytes) + try { + validate(tempArtifact) + } catch (throwable: Throwable) { + fileSystem.delete(temp, mustExist = false) + throw throwable + } + + fileSystem.atomicMove(temp, target) + val metadata = fileSystem.metadataOrNull(target) + return target.toStoredArtifact(metadata?.size ?: sizeBytes) + } + + private fun Path.toStoredArtifact(sizeBytes: Long?): DataSourceStoredArtifact { + val path = this + return DataSourceStoredArtifact( + localPath = path.toString(), + sizeBytes = sizeBytes, + sourceOpener = { fileSystem.source(path).buffered() } + ) + } +} + +/** + * Platform-neutral resolver implementation for local files, HTTP(S), and + * Hugging Face source URIs. Storage, network, auth, and checksum details are + * injected so this policy can be reused by each KMP target. 
+ */ +public class DefaultDataSourceResolver( + private val store: DataSourceArtifactStore, + private val fetcher: RemoteDataSourceFetcher, + private val checksum: DataSourceChecksum, + private val headerProvider: DataSourceHeaderProvider = HuggingFaceTokenHeaderProvider() +) : DataSourceResolver { + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact { + val parsed = DataSourceUriParser.parse(request.uri) + return when (parsed.provider) { + DataSourceProvider.File -> resolveFile(request, parsed) + DataSourceProvider.Http, DataSourceProvider.HuggingFace -> resolveRemote(request, parsed) + } + } + + private suspend fun resolveFile( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val path = parsed.localPath ?: throw DataSourceException("File source has no local path: ${request.uri}") + val stored = store.readLocal(path) + ?: throw DataSourceException("Data source file not found: $path") + request.expectedSha256?.let { verifySha256(stored, it, request.uri) } + return stored.toArtifact(request, parsed, cacheHit = true) + } + + private suspend fun resolveRemote( + request: DataSourceRequest, + parsed: ParsedDataSourceUri + ): DataSourceArtifact { + val canUseCache = request.cachePolicy == CachePolicy.Use || request.cachePolicy == CachePolicy.Offline + if (canUseCache) { + val cached = store.readCache(parsed.cacheKey) + if (cached != null) { + request.expectedSha256?.let { verifySha256(cached, it, request.uri) } + return cached.toArtifact(request, parsed, cacheHit = true) + } + } + + if (request.cachePolicy == CachePolicy.Offline) { + throw DataSourceException("No cached artifact available for offline source: ${request.uri}") + } + + val remote = fetcher.fetch(parsed.transportUri, headerProvider.headers(request, parsed)) + + if (request.cachePolicy == CachePolicy.Bypass) { + val stored = try { + DataSourceStoredArtifact.inMemoryFrom(remote.source, sizeBytes = remote.sizeBytes) + } finally { + 
remote.source.close() + } + request.expectedSha256?.let { verifySha256(stored, it, request.uri) } + return stored.toArtifact(request, parsed, cacheHit = false) + } + + val expectedSha256 = request.expectedSha256 + val hash = expectedSha256?.let { checksum.newSha256() } + val source = hash?.let { HashingRawSource(remote.source, it).buffered() } ?: remote.source + val stored = try { + store.writeCache( + cacheKey = parsed.cacheKey, + source = source, + sizeBytes = remote.sizeBytes, + validate = { + if (expectedSha256 != null && hash != null) { + verifySha256Hex(hash.hex(), expectedSha256, request.uri) + } + } + ) + } finally { + source.close() + } + return stored.toArtifact(request, parsed, cacheHit = false) + } + + private suspend fun DataSourceStoredArtifact.toArtifact( + request: DataSourceRequest, + parsed: ParsedDataSourceUri, + cacheHit: Boolean + ): DataSourceArtifact { + return DataSourceArtifact( + request = request, + parsedUri = parsed, + filename = parsed.filename, + localPath = localPath, + sizeBytes = sizeBytes, + cacheHit = cacheHit, + sourceOpener = { openSource() } + ) + } + + private suspend fun verifySha256(artifact: DataSourceStoredArtifact, expected: String, uri: String) { + val actual = artifact.sha256Hex() + verifySha256Hex(actual, expected, uri) + } + + private fun verifySha256Hex(actual: String, expected: String, uri: String) { + if (!actual.equals(expected, ignoreCase = true)) { + throw DataSourceException( + "SHA-256 mismatch for $uri: expected ${expected.lowercase()}, actual $actual" + ) + } + } + + private suspend fun DataSourceStoredArtifact.sha256Hex(): String { + val hash = checksum.newSha256() + val buffer = ByteArray(STREAM_BUFFER_SIZE) + val source = openSource() + try { + while (true) { + val read = source.readAtMostTo(buffer) + if (read == -1) break + hash.update(buffer, endIndex = read) + } + } finally { + source.close() + } + return hash.hex() + } + + private companion object { + private const val STREAM_BUFFER_SIZE = 8 * 1024 + 
} +} + +private fun Map.hasAuthorizationHeader(): Boolean { + return keys.any { it.equals("Authorization", ignoreCase = true) } +} + +private class HashingRawSource( + private val source: Source, + private val hash: DataSourceHash +) : RawSource { + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long { + val start = sink.size + val read = source.readAtMostTo(sink, byteCount) + if (read > 0) { + val copied = Buffer() + sink.copyTo(copied, startIndex = start, endIndex = start + read) + hash.update(copied.readByteArray()) + } + return read + } + + override fun close() { + source.close() + } +} + +private fun ByteArray.toDataSourceSource(): Source { + val buffer = Buffer() + buffer.write(this) + return buffer +} diff --git a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt new file mode 100644 index 00000000..7bcef66f --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DataSourceUriParserTest.kt @@ -0,0 +1,89 @@ +package sk.ainet.data.source + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertNull + +class DataSourceUriParserTest { + @Test + fun parsesFileUri() { + val parsed = DataSourceUriParser.parse("file:///tmp/skainet/train-images.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("/tmp/skainet/train-images.idx", parsed.localPath) + assertEquals("train-images.idx", parsed.filename) + } + + @Test + fun parsesJvmFileUri() { + val parsed = DataSourceUriParser.parse("file:/tmp/skainet/train-images.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("/tmp/skainet/train-images.idx", parsed.localPath) + assertEquals("train-images.idx", parsed.filename) + } + + @Test + fun parsesPlainPathAsFile() { + 
val parsed = DataSourceUriParser.parse("fixtures/mnist/train-labels.idx") + + assertEquals(DataSourceProvider.File, parsed.provider) + assertEquals("fixtures/mnist/train-labels.idx", parsed.localPath) + assertEquals("train-labels.idx", parsed.filename) + } + + @Test + fun parsesHttpUri() { + val parsed = DataSourceUriParser.parse("https://example.test/data/sample.csv?download=1") + + assertEquals(DataSourceProvider.Http, parsed.provider) + assertEquals("sample.csv", parsed.filename) + assertNull(parsed.huggingFace) + } + + @Test + fun parsesHuggingFaceHttpsProviderPrefix() { + val parsed = DataSourceUriParser.parse( + "hf+https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json" + ) + + assertEquals(DataSourceProvider.HuggingFace, parsed.provider) + assertEquals( + "https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct/resolve/main/tokenizer.json", + parsed.transportUri + ) + assertEquals("tokenizer.json", parsed.filename) + } + + @Test + fun parsesHuggingFaceDatasetShorthand() { + val parsed = DataSourceUriParser.parse("hf://datasets/mnist/mnist@main/plain_text/train-00000.parquet") + + assertEquals(DataSourceProvider.HuggingFace, parsed.provider) + assertEquals(HuggingFaceRepoType.Dataset, parsed.huggingFace?.repoType) + assertEquals("mnist/mnist", parsed.huggingFace?.repoId) + assertEquals("main", parsed.huggingFace?.revision) + assertEquals("plain_text/train-00000.parquet", parsed.huggingFace?.path) + assertEquals( + "https://huggingface.co/datasets/mnist/mnist/resolve/main/plain_text/train-00000.parquet", + parsed.transportUri + ) + } + + @Test + fun cacheKeyDependsOnNormalizedUri() { + val first = DataSourceUriParser.parse("https://example.test/a.txt") + val second = DataSourceUriParser.parse("https://example.test/b.txt") + + assertNotEquals(first.cacheKey, second.cacheKey) + } + + @Test + fun rejectsUnknownSchemes() { + assertFailsWith { + DataSourceUriParser.parse("s3://bucket/object") + } + } +} diff --git 
a/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt new file mode 100644 index 00000000..e3803588 --- /dev/null +++ b/skainet-data/skainet-data-source/src/commonTest/kotlin/sk/ainet/data/source/DefaultDataSourceResolverTest.kt @@ -0,0 +1,257 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import kotlinx.io.Buffer +import kotlinx.io.Source +import kotlinx.io.readByteArray +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class DefaultDataSourceResolverTest { + @Test + fun resolvesLocalArtifactsThroughStore() = runTest { + val store = MemoryDataSourceByteStore( + localArtifacts = mapOf("/data/sample.bin" to "local".encodeToByteArray()) + ) + val fetcher = RecordingFetcher("remote".encodeToByteArray()) + val resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + + val artifact = resolver.resolve(DataSourceRequest("/data/sample.bin")) + + assertEquals("/data/sample.bin", artifact.localPath) + assertTrue(artifact.cacheHit) + assertEquals(0, fetcher.calls) + assertContentEquals("local".encodeToByteArray(), artifact.readBytes()) + + val copied = Buffer() + assertEquals(5, artifact.copyTo(copied)) + assertContentEquals("local".encodeToByteArray(), copied.readByteArray()) + } + + @Test + fun cachesRemoteArtifactsAndReusesThem() = runTest { + val store = MemoryDataSourceByteStore() + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + val request = DataSourceRequest( + uri = "https://example.test/data.bin", + expectedSha256 = "sha:payload" + ) + + val first = resolver.resolve(request) + val second = resolver.resolve(request) + + 
assertFalse(first.cacheHit) + assertTrue(second.cacheHit) + assertEquals(1, fetcher.calls) + assertEquals(1, store.cacheWrites) + assertContentEquals("payload".encodeToByteArray(), second.readBytes()) + } + + @Test + fun bypassSkipsPersistentCache() = runTest { + val store = MemoryDataSourceByteStore() + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver(store, fetcher, TestChecksum) + + val artifact = resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + cachePolicy = CachePolicy.Bypass + ) + ) + + assertEquals(null, artifact.localPath) + assertFalse(artifact.cacheHit) + assertEquals(1, fetcher.calls) + assertEquals(0, store.cacheWrites) + } + + @Test + fun verifiesChecksumsInCommonCore() = runTest { + val store = MemoryDataSourceByteStore() + val resolver = DefaultDataSourceResolver( + store = store, + fetcher = RecordingFetcher("payload".encodeToByteArray()), + checksum = TestChecksum + ) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + expectedSha256 = "sha:other" + ) + ) + } + assertEquals(0, store.cacheWrites) + } + + @Test + fun forwardsProviderHeadersToFetcher() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum, + headerProvider = DataSourceHeaderProvider { request, parsedUri -> + request.headers + ("X-SKaiNET-Provider" to parsedUri.provider.name) + } + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/org/repo@main/file.bin", + headers = mapOf("Accept" to "application/octet-stream") + ) + ) + + assertEquals( + mapOf( + "Accept" to "application/octet-stream", + "X-SKaiNET-Provider" to "HuggingFace" + ), + fetcher.lastHeaders + ) + } + + @Test + fun addsHuggingFaceTokenFromRequest() = runTest { + val fetcher = 
RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum + ) + val token = DataSourceAuthToken.from("hf_request") + + resolver.resolve( + DataSourceRequest( + uri = "hf://datasets/org/repo@main/file.bin", + headers = mapOf("Accept" to "application/octet-stream"), + huggingFaceToken = token + ) + ) + + assertEquals( + mapOf( + "Accept" to "application/octet-stream", + "Authorization" to "Bearer hf_request" + ), + fetcher.lastHeaders + ) + assertEquals("DataSourceAuthToken(***)", token.toString()) + } + + @Test + fun keepsExistingAuthorizationHeaderOverHuggingFaceToken() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://org/repo@main/file.bin", + headers = mapOf("authorization" to "Bearer explicit"), + huggingFaceToken = DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(mapOf("authorization" to "Bearer explicit"), fetcher.lastHeaders) + } + + @Test + fun doesNotAddHuggingFaceTokenToGenericHttp() = runTest { + val fetcher = RecordingFetcher("payload".encodeToByteArray()) + val resolver = DefaultDataSourceResolver( + store = MemoryDataSourceByteStore(), + fetcher = fetcher, + checksum = TestChecksum + ) + + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/data.bin", + huggingFaceToken = DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(emptyMap(), fetcher.lastHeaders) + } +} + +private class MemoryDataSourceByteStore( + private val localArtifacts: Map = emptyMap() +) : DataSourceArtifactStore { + private val cacheArtifacts = mutableMapOf() + + var cacheWrites: Int = 0 + private set + + override suspend fun readLocal(path: String): DataSourceStoredArtifact? 
{ + return localArtifacts[path]?.storedAt(path) + } + + override suspend fun readCache(cacheKey: String): DataSourceStoredArtifact? { + return cacheArtifacts[cacheKey] + } + + override suspend fun writeCache( + cacheKey: String, + source: Source, + sizeBytes: Long?, + validate: suspend (DataSourceStoredArtifact) -> Unit + ): DataSourceStoredArtifact { + val stored = DataSourceStoredArtifact.inMemoryFrom( + source = source, + localPath = "/cache/$cacheKey", + sizeBytes = sizeBytes + ) + validate(stored) + cacheWrites++ + cacheArtifacts[cacheKey] = stored + return stored + } + + private fun ByteArray.storedAt(path: String): DataSourceStoredArtifact { + val bytes = copyOf() + return DataSourceStoredArtifact.inMemory(bytes, localPath = path) + } +} + +private class RecordingFetcher( + private val bytes: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + var lastHeaders: Map = emptyMap() + private set + + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { + calls++ + lastHeaders = headers + return DataSourceRemoteContent.fromBytes(bytes.copyOf()) + } +} + +private object TestChecksum : DataSourceChecksum { + override fun newSha256(): DataSourceHash = TestHash() +} + +private class TestHash : DataSourceHash { + private val text = StringBuilder() + + override fun update(bytes: ByteArray, startIndex: Int, endIndex: Int) { + text.append(bytes.copyOfRange(startIndex, endIndex).decodeToString()) + } + + override fun hex(): String = "sha:$text" +} diff --git a/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt new file mode 100644 index 00000000..fd7c9a88 --- /dev/null +++ b/skainet-data/skainet-data-source/src/jvmMain/kotlin/sk/ainet/data/source/JvmDataSourceResolver.kt @@ -0,0 +1,109 @@ +package sk.ainet.data.source + +import io.ktor.client.HttpClient +import 
io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.request.get +import io.ktor.client.request.header +import io.ktor.client.statement.bodyAsChannel +import io.ktor.http.HttpHeaders +import io.ktor.utils.io.asSource +import kotlinx.io.buffered +import kotlinx.io.files.Path +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import java.io.File +import java.security.MessageDigest + +/** + * Ktor/CIO-backed remote fetcher for JVM data artifacts. + */ +public class KtorRemoteDataSourceFetcher( + private val client: HttpClient = HttpClient(CIO) { + expectSuccess = true + install(HttpTimeout) { + requestTimeoutMillis = 600_000 + connectTimeoutMillis = 60_000 + socketTimeoutMillis = 600_000 + } + } +) : RemoteDataSourceFetcher, AutoCloseable { + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { + val response = client.get(uri) { + headers.forEach { (name, value) -> header(name, value) } + } + return DataSourceRemoteContent( + source = response.bodyAsChannel().asSource().buffered(), + sizeBytes = response.headers[HttpHeaders.ContentLength]?.toLongOrNull() + ) + } + + override fun close() { + client.close() + } +} + +/** + * JVM resolver for local files and cached remote artifacts. + */ +public class JvmDataSourceResolver( + cacheDir: File = defaultCacheDir(), + fetcher: RemoteDataSourceFetcher = KtorRemoteDataSourceFetcher(), + huggingFaceToken: DataSourceAuthToken? 
= null, + useEnvironmentHuggingFaceToken: Boolean = false +) : DataSourceResolver { + private val delegate = DefaultDataSourceResolver( + store = FileSystemDataSourceArtifactStore(Path(cacheDir.absolutePath)), + fetcher = fetcher, + checksum = JvmSha256DataSourceChecksum, + headerProvider = HuggingFaceTokenHeaderProvider( + JvmHuggingFaceTokenProvider( + configuredToken = huggingFaceToken, + useEnvironmentToken = useEnvironmentHuggingFaceToken + ) + ) + ) + + override suspend fun resolve(request: DataSourceRequest): DataSourceArtifact = withContext(Dispatchers.IO) { + delegate.resolve(request) + } + + public companion object { + public fun defaultCacheDir(): File { + val userHome = System.getProperty("user.home")?.takeIf { it.isNotBlank() } + val base = userHome ?: System.getProperty("java.io.tmpdir") + return File(base, ".cache/skainet/data") + } + } +} + +internal class JvmHuggingFaceTokenProvider( + private val configuredToken: DataSourceAuthToken?, + private val useEnvironmentToken: Boolean +) : HuggingFaceTokenProvider { + override fun token(request: DataSourceRequest, parsedUri: ParsedDataSourceUri): DataSourceAuthToken? 
{ + if (parsedUri.provider != DataSourceProvider.HuggingFace) return null + configuredToken?.let { return it } + if (!useEnvironmentToken) return null + return DataSourceAuthToken.fromOrNull(System.getenv("HF_TOKEN")) + ?: DataSourceAuthToken.fromOrNull(System.getenv("HUGGING_FACE_HUB_TOKEN")) + } +} + +internal object JvmSha256DataSourceChecksum : DataSourceChecksum { + override fun newSha256(): DataSourceHash = JvmSha256DataSourceHash() +} + +private class JvmSha256DataSourceHash : DataSourceHash { + private val digest = MessageDigest.getInstance("SHA-256") + + override fun update(bytes: ByteArray, startIndex: Int, endIndex: Int) { + digest.update(bytes, startIndex, endIndex - startIndex) + } + + override fun hex(): String { + return digest + .digest() + .joinToString("") { byte -> "%02x".format(byte) } + } +} diff --git a/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt new file mode 100644 index 00000000..4f997cd4 --- /dev/null +++ b/skainet-data/skainet-data-source/src/jvmTest/kotlin/sk/ainet/data/source/JvmDataSourceResolverTest.kt @@ -0,0 +1,211 @@ +package sk.ainet.data.source + +import kotlinx.coroutines.test.runTest +import java.nio.file.Files +import kotlin.test.Test +import kotlin.test.assertContentEquals +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class JvmDataSourceResolverTest { + @Test + fun resolvesLocalFileUri() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val file = root.resolve("sample.txt") + file.writeText("hello") + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache")) + + val artifact = resolver.resolve(DataSourceRequest(file.toURI().toString())) + + assertEquals("sample.txt", artifact.filename) 
+ assertEquals(file.canonicalPath, artifact.localPath) + assertTrue(artifact.cacheHit) + assertContentEquals("hello".encodeToByteArray(), artifact.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun cachesRemoteArtifacts() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("first".encodeToByteArray()) + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = fetcher) + val request = DataSourceRequest( + "hf+https://huggingface.co/example/model/resolve/main/config.json" + ) + + val first = resolver.resolve(request) + val second = resolver.resolve(request) + + assertEquals(1, fetcher.calls) + assertFalse(first.cacheHit) + assertTrue(second.cacheHit) + assertNotNull(second.localPath) + assertContentEquals("first".encodeToByteArray(), second.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun refreshFetchesAgain() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = QueueFetcher( + "old".encodeToByteArray(), + "new".encodeToByteArray() + ) + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = fetcher) + val uri = "https://example.test/data.bin" + + resolver.resolve(DataSourceRequest(uri)).readBytes() + val refreshed = resolver.resolve(DataSourceRequest(uri, cachePolicy = CachePolicy.Refresh)) + + assertEquals(2, fetcher.calls) + assertContentEquals("new".encodeToByteArray(), refreshed.readBytes()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun offlineFailsWhenCacheIsMissing() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val resolver = JvmDataSourceResolver(cacheDir = root.resolve("cache"), fetcher = FakeFetcher(ByteArray(0))) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/missing.bin", + cachePolicy = CachePolicy.Offline 
+ ) + ) + } + } finally { + root.deleteRecursively() + } + } + + @Test + fun bypassDoesNotWriteCache() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("bytes".encodeToByteArray()) + val cacheDir = root.resolve("cache") + val resolver = JvmDataSourceResolver(cacheDir = cacheDir, fetcher = fetcher) + + val artifact = resolver.resolve( + DataSourceRequest("https://example.test/data.bin", cachePolicy = CachePolicy.Bypass) + ) + + assertEquals(1, fetcher.calls) + assertEquals(null, artifact.localPath) + assertFalse(cacheDir.exists()) + } finally { + root.deleteRecursively() + } + } + + @Test + fun verifiesSha256() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = FakeFetcher("payload".encodeToByteArray()) + ) + + assertFailsWith { + resolver.resolve( + DataSourceRequest( + uri = "https://example.test/payload.bin", + expectedSha256 = "0000" + ) + ) + } + } finally { + root.deleteRecursively() + } + } + + @Test + fun sendsConfiguredHuggingFaceToken() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("payload".encodeToByteArray()) + val resolver = JvmDataSourceResolver( + cacheDir = root.resolve("cache"), + fetcher = fetcher, + huggingFaceToken = DataSourceAuthToken.from("hf_configured"), + useEnvironmentHuggingFaceToken = false + ) + + resolver.resolve(DataSourceRequest("hf://org/repo@main/file.bin")) + + assertEquals(mapOf("Authorization" to "Bearer hf_configured"), fetcher.lastHeaders) + } finally { + root.deleteRecursively() + } + } + + @Test + fun requestHuggingFaceTokenOverridesConfiguredToken() = runTest { + val root = Files.createTempDirectory("skainet-data-source-test").toFile() + try { + val fetcher = FakeFetcher("payload".encodeToByteArray()) + val resolver = JvmDataSourceResolver( + 
cacheDir = root.resolve("cache"), + fetcher = fetcher, + huggingFaceToken = DataSourceAuthToken.from("hf_configured"), + useEnvironmentHuggingFaceToken = false + ) + + resolver.resolve( + DataSourceRequest( + uri = "hf://org/repo@main/file.bin", + huggingFaceToken = DataSourceAuthToken.from("hf_request") + ) + ) + + assertEquals(mapOf("Authorization" to "Bearer hf_request"), fetcher.lastHeaders) + } finally { + root.deleteRecursively() + } + } +} + +private class FakeFetcher( + private val bytes: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + var lastHeaders: Map = emptyMap() + private set + + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { + calls++ + lastHeaders = headers + return DataSourceRemoteContent.fromBytes(bytes) + } +} + +private class QueueFetcher( + private vararg val responses: ByteArray +) : RemoteDataSourceFetcher { + var calls: Int = 0 + private set + + override suspend fun fetch(uri: String, headers: Map): DataSourceRemoteContent { + val index = calls.coerceAtMost(responses.lastIndex) + calls++ + return DataSourceRemoteContent.fromBytes(responses[index]) + } +}