
NN with Multimodal inputs #201


@michalharakal

Multimodal Input Models in Skainet

Last updated: 2025-11-17


Problem statement

You want to build models that take multiple inputs (e.g., an image tensor and an auxiliary vector), process them in separate branches, fuse them (typically by flattening the spatial branch and concatenating its features with the other inputs), and then continue with shared layers, ideally all defined as a single computation.

Example target flow:

  • Image branch: Conv2D/Pool → Flatten → features (N, F)
  • Vector branch: Pre-existing vector features (N, D)
  • Fusion: Concat(features, dim = -1) → (N, F + D)
  • Head: Dense/MLP → Output

Current status (what works today)

  • Core ops support the necessary building blocks:
    • Flatten: preserve batch dimension; collapse spatial dims.
    • Concat: concatenate tensors along any dimension (supports negative dim, shape checks).
    • Dense, Conv2D, MaxPool2D, Activations are available.
  • There are three practical approaches to multimodal fusion, two available today and one proposed for near-term use:
    1. Without any new API (compose two modules and concat at call site)
    2. With a minimal Functional API for multi-input single-computation style
    3. Proposed DSL extensions to support branching, named inputs, and Concatenate in one definition block
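The Flatten and Concat rules listed above can be summarized as shape-only functions. The sketch below is illustrative: `normalizeDim`, `flattenShape`, and `concatShape` are hypothetical helpers written for this note, not Skainet APIs, and they only model the described behavior (keep the batch dimension, accept a negative `dim`, reject mismatched non-concat axes).

```kotlin
// Hypothetical shape-only helpers modeling the Flatten/Concat rules described above.
// Shapes are plain List<Int>; these are NOT Skainet APIs.

/** Normalizes a possibly negative dimension index against a tensor rank. */
fun normalizeDim(dim: Int, rank: Int): Int {
    val d = if (dim < 0) dim + rank else dim
    require(d in 0 until rank) { "dim $dim out of range for rank $rank" }
    return d
}

/** Collapses dims startDim..endDim (inclusive) into one; the defaults keep the batch dim. */
fun flattenShape(shape: List<Int>, startDim: Int = 1, endDim: Int = -1): List<Int> {
    val s = normalizeDim(startDim, shape.size)
    val e = normalizeDim(endDim, shape.size)
    require(s <= e) { "startDim must not come after endDim" }
    val collapsed = shape.subList(s, e + 1).fold(1) { acc, n -> acc * n }
    return shape.subList(0, s) + collapsed + shape.subList(e + 1, shape.size)
}

/** Concat shape: every dim except `dim` must match; sizes along `dim` are summed. */
fun concatShape(shapes: List<List<Int>>, dim: Int = -1): List<Int> {
    require(shapes.isNotEmpty()) { "need at least one tensor" }
    val rank = shapes[0].size
    require(shapes.all { it.size == rank }) { "all tensors must have the same rank" }
    val d = normalizeDim(dim, rank)
    for (axis in 0 until rank) {
        if (axis == d) continue
        require(shapes.all { it[axis] == shapes[0][axis] }) {
            "size mismatch on axis $axis: ${shapes.map { it[axis] }}"
        }
    }
    return shapes[0].mapIndexed { axis, n -> if (axis == d) shapes.sumOf { it[d] } else n }
}
```

For example, flattening a (8, 16, 13, 13) feature map gives (8, 2704), and concatenating that with an (8, 10) vector along `dim = -1` gives (8, 2714).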

Solutions overview

  1. No new API: compose modules + concat at call site (available now)
  • Build your CNN image branch as a sequential DSL module that ends with flatten().
  • Build your dense head as another sequential DSL module.
  • At the call site, run the image branch on the image tensor, then concat its flattened result with the auxiliary vector along the last dimension, then run the dense head.
  • Pros: works today with existing DSL and ops; minimal surface area.
  • Cons: wiring lives outside the model definition; not a single definition block.

Key snippets:

  • Flatten (DSL): stage("flatten") { flatten() } → shape (N, C⋅H⋅W)
  • Concat (ops): ops.concat(listOf(a, b), dim = -1) → shape (N, F + D)
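A self-contained sketch of Solution 1 follows. It uses toy `Tensor` and `Module` types (hypothetical, not Skainet's classes) so the call-site wiring can run on its own: the image branch and the head stay independent modules, and `fuse` performs the concat between them. In Skainet code the `concatLastDim` step would be the `ops.concat(listOf(a, b), dim = -1)` call shown above.

```kotlin
// Toy types for illustration only; Skainet's real Tensor/Module types differ.
class Tensor(val shape: List<Int>, val data: FloatArray) {
    init { require(shape.fold(1) { a, n -> a * n } == data.size) { "shape/data size mismatch" } }
}

// A module is modeled here as a function from one tensor to another.
fun interface Module { fun forward(x: Tensor): Tensor }

/** Concatenates rank-2 tensors (N, F) and (N, D) along the last dim -> (N, F + D). */
fun concatLastDim(a: Tensor, b: Tensor): Tensor {
    require(a.shape.size == 2 && b.shape.size == 2) { "expected rank-2 inputs" }
    val n = a.shape[0]
    require(b.shape[0] == n) { "batch sizes differ: $n vs ${b.shape[0]}" }
    val f = a.shape[1]
    val d = b.shape[1]
    val out = FloatArray(n * (f + d))
    for (row in 0 until n) {
        // Each fused row is this sample's F features followed by its D features.
        a.data.copyInto(out, destinationOffset = row * (f + d), startIndex = row * f, endIndex = (row + 1) * f)
        b.data.copyInto(out, destinationOffset = row * (f + d) + f, startIndex = row * d, endIndex = (row + 1) * d)
    }
    return Tensor(listOf(n, f + d), out)
}

/** Solution 1: the wiring lives at the call site, outside either module. */
fun fuse(imageBranch: Module, head: Module, image: Tensor, vector: Tensor): Tensor {
    val features = imageBranch.forward(image)    // (N, F); the branch ends with flatten()
    val merged = concatLastDim(features, vector) // (N, F + D)
    return head.forward(merged)
}
```

Because row-major tensors store each sample's features contiguously, concatenating along the last dimension interleaves one row of each input per sample, which is why the copy happens row by row.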

Related test/example in repo:

  • skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/FeatureFusionFlattenConcatTest.kt

  2. Minimal Functional API (committed) for single-computation, multi-input graphs (available now)
  • A lightweight Functional wrapper allows you to describe the entire forward pass as a single function of named inputs.
  • Retrieve inputs by name, run any sequence of ops (Conv2D/Pool/Flatten/Concat/Dense), and return the final tensor.
  • Pros: Expresses multi-input models as a single computation; easy to branch/merge; stays close to common “functional API” patterns.
  • Cons: Concatenation is still via ops (no layer node); not integrated into the sequential DSL builder.

Files:

  • API: skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/Functional.kt
  • Example test: skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/FunctionalMultiInputModelTest.kt

Usage sketch:

  • Declare inputs: Functional.of(inputs = listOf(FuncInput("matrixInput"), FuncInput("vectorInput"))) { args, ctx -> ... }
  • Inside, do: Conv2D → MaxPool2D → Flatten → ops.concat(listOf(flat, vector), dim = -1) → Dense → Output
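The following sketch shows the same idea with shapes only. `ShapeFunctional` and `NamedInput` are hypothetical stand-ins for `Functional.of` and `FuncInput` (the real signatures, including the `ctx` argument, are in Functional.kt), and Conv2D/MaxPool2D are omitted to keep it short. The point is that the whole forward pass is one function of a map of named inputs, and a missing input fails before the body runs.

```kotlin
// Hypothetical stand-ins that mimic the *idea* of Functional.of with named inputs,
// not its real signature. Values are shapes only, similar to a shape-only backend.
data class NamedInput(val name: String)

class ShapeFunctional(
    private val inputs: List<NamedInput>,
    private val body: (Map<String, List<Int>>) -> List<Int>,
) {
    /** Runs the entire forward pass as a single function of the named inputs. */
    fun forward(feed: Map<String, List<Int>>): List<Int> {
        val missing = inputs.map { it.name }.filterNot { it in feed }
        require(missing.isEmpty()) { "missing inputs: $missing" }
        return body(feed)
    }
}

// Two named inputs, one computation (Conv2D/MaxPool2D omitted for brevity).
val model = ShapeFunctional(listOf(NamedInput("matrixInput"), NamedInput("vectorInput"))) { args ->
    val img = args.getValue("matrixInput")                            // (N, C, H, W)
    val vec = args.getValue("vectorInput")                            // (N, D)
    require(img[0] == vec[0]) { "batch sizes differ" }
    val flat = listOf(img[0], img.drop(1).fold(1) { a, n -> a * n }) // Flatten -> (N, C*H*W)
    val merged = listOf(flat[0], flat[1] + vec[1])                    // Concat(dim = -1) -> (N, F + D)
    listOf(merged[0], 1)                                              // Dense with 1 output -> (N, 1)
}
```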

  3. Proposed DSL with branching and named inputs (design/roadmap)
  • Goal: Enable multimodal fusion within a single DSL definition block, without external glue code or a separate Functional wrapper.

Minimal additions required:

  • Multi-input entrypoint: a model interface that accepts a feed of named input tensors, e.g., forward(feed: Map<String, Tensor<*, *>>, ctx).
  • DSL constructs:
    • inputs { input(name = "matrixInput", shape = …); input(name = "vectorInput", shape = …) }
    • branch(from = …) { … } to build subgraphs from named inputs or intermediate nodes
    • output(node) to select the final node
  • Concatenate node in the DSL:
    • val merged = concat(dim = -1) { use(nodeA); use(nodeB) }
    • Internally calls existing ops.concat at execution

Example desired usage:

val model = definition<FP32, Float> {
    network {
        inputs {
            input(name = "matrixInput", shape = Shape(-1, 1, 28, 28))
            input(name = "vectorInput", shape = Shape(-1, 10))
        }

        val image = branch(from = "matrixInput") {
            conv2d(outChannels = 16, kernelSize = 3 to 3)
            maxPool2d(kernelSize = 2 to 2, stride = 2 to 2)
            flatten()
        }

        val merged = concat(dim = -1) {
            use(image)
            use(from("vectorInput"))
        }

        val head = branch(from = merged) {
            dense(outputDimension = 32)
            dense(outputDimension = 1)
        }

        output(head)
    }
}
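As a sanity check on the example above, this sketch walks a batch through each stage. The example does not specify conv2d stride or padding, so the sketch assumes stride 1 and no padding; `outSize` applies the usual window formula floor((H + 2P − K) / S) + 1. Function names are illustrative only.

```kotlin
// Worked shape arithmetic for the proposed example. Assumption: conv2d uses stride 1 and
// no padding, since the example does not say. Names are illustrative, not Skainet APIs.

/** Standard output size of a conv/pool window along one spatial axis. */
fun outSize(input: Int, kernel: Int, stride: Int = 1, padding: Int = 0): Int =
    (input + 2 * padding - kernel) / stride + 1

fun exampleShapes(batch: Int): List<Pair<String, List<Int>>> {
    val conv = outSize(28, kernel = 3)               // 26
    val pool = outSize(conv, kernel = 2, stride = 2) // 13
    val flat = 16 * pool * pool                      // 2704
    return listOf(
        "matrixInput" to listOf(batch, 1, 28, 28),
        "conv2d" to listOf(batch, 16, conv, conv),
        "maxPool2d" to listOf(batch, 16, pool, pool),
        "flatten" to listOf(batch, flat),
        "concat" to listOf(batch, flat + 10),        // + vectorInput features (D = 10)
        "dense(32)" to listOf(batch, 32),
        "dense(1)" to listOf(batch, 1),
    )
}
```

With "same" padding (P = 1 for a 3×3 kernel) the conv output would stay 28×28 and the fused width would change accordingly, so the padding choice matters for any shape test.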

Status: Not implemented yet. Requires:

  • A MultiInputModule interface (or equivalent) and a small internal DAG builder
  • DSL surface for inputs/branch/concat/output
  • Unit tests for shape correctness and error cases
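A minimal sketch of what the MultiInputModule and DAG builder could look like, assuming shape-only evaluation (similar in spirit to the Void backend). Every name here (`MultiInputModule`, `GraphBuilder`, `InputNode`, `OpNode`) is hypothetical, the `ctx` parameter from the proposed `forward` signature is omitted, and shapes stand in for tensors. Nodes are evaluated on demand starting from the output node, and each node is computed once, so shared subgraphs are not recomputed.

```kotlin
// Hypothetical sketch of the proposed MultiInputModule + DAG builder; none of these names
// exist in Skainet. Shapes stand in for tensors (shape-only evaluation), and ctx is omitted.
typealias Shape = List<Int>

interface MultiInputModule {
    fun forward(feed: Map<String, Shape>): Shape
}

sealed interface Node
data class InputNode(val name: String) : Node
class OpNode(val inputs: List<Node>, val op: (List<Shape>) -> Shape) : Node

class GraphBuilder {
    /** A named graph input, resolved from the feed at forward time. */
    fun input(name: String): Node = InputNode(name)

    /** Chains single-input stages starting from `from` (an input or an intermediate node). */
    fun branch(from: Node, vararg stages: (Shape) -> Shape): Node =
        stages.fold(from) { node, stage -> OpNode(listOf(node)) { shapes -> stage(shapes.single()) } }

    /** Concatenates along the last dim; all other dims must match. */
    fun concat(vararg nodes: Node): Node = OpNode(nodes.toList()) { shapes ->
        require(shapes.all { it.dropLast(1) == shapes[0].dropLast(1) }) { "non-concat dims differ: $shapes" }
        shapes[0].dropLast(1) + shapes.sumOf { it.last() }
    }

    /** Selects the final node and returns a runnable multi-input module. */
    fun output(node: Node): MultiInputModule = object : MultiInputModule {
        override fun forward(feed: Map<String, Shape>): Shape {
            val cache = HashMap<Node, Shape>() // memoizes each node so shared subgraphs run once
            fun eval(n: Node): Shape = cache.getOrPut(n) {
                when (n) {
                    is InputNode -> feed[n.name] ?: error("missing input '${n.name}'")
                    is OpNode -> n.op(n.inputs.map { eval(it) })
                }
            }
            return eval(node)
        }
    }
}
```

In this sketch a branch is just a chain of single-input nodes, so the proposed `branch(from = …)` and `concat(dim = -1) { use(…) }` constructs map onto `branch` and `concat`; the eventual DSL design may differ.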

Practical guidance

  • Batch alignment: Both branches must share the same batch size N before concatenation.
  • Flatten correctly: For NCHW inputs (batch, channels, height, width), flatten with startDim = 1, endDim = -1 (the DSL’s flatten() does this) so the batch dimension stays intact.
  • Concat axis: Use dim = -1 to append features along the last dimension (typical for feature fusion).
  • Mismatched sizes: If feature sizes don’t match downstream expectations, insert a Dense/Linear projection before fusion or in the head.
  • Backend note: The default Void backend performs shape-only ops, which is useful for prototyping and tests; numeric backends can be integrated later.
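The batch-alignment and projection advice can be sketched with plain row-major arrays. `linear` and `checkBatchAligned` are hypothetical helpers, not Skainet's Dense/Linear layer; `linear` computes y = xW + b, which is one way to resize a branch's features before fusion.

```kotlin
// Illustrative sketch with hypothetical helpers and plain arrays (not Skainet's layers).

/** Dense/Linear projection: x is (N, inDim), w is (inDim, outDim), b is (outDim); returns (N, outDim). */
fun linear(x: FloatArray, n: Int, inDim: Int, w: FloatArray, b: FloatArray): FloatArray {
    val outDim = b.size
    require(x.size == n * inDim && w.size == inDim * outDim) { "shape mismatch" }
    val y = FloatArray(n * outDim)
    for (row in 0 until n) {
        for (j in 0 until outDim) {
            var acc = b[j]
            for (k in 0 until inDim) acc += x[row * inDim + k] * w[k * outDim + j]
            y[row * outDim + j] = acc
        }
    }
    return y
}

/** Batch alignment: both branches must share the leading (batch) dimension before concat. */
fun checkBatchAligned(a: List<Int>, b: List<Int>) =
    require(a[0] == b[0]) { "batch mismatch: ${a[0]} vs ${b[0]}" }
```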

References in this repository

  • Tests and examples
    • Feature flatten + concat using DSL tensor creation and ops:
      • skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/FeatureFusionFlattenConcatTest.kt
    • Single-computation multi-input example using Functional API:
      • skainet-lang/skainet-lang-core/src/commonTest/kotlin/sk/ainet/lang/nn/FunctionalMultiInputModelTest.kt
  • API and core modules
    • Functional API: skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/Functional.kt
    • Layers and DSL builder (Conv2d, Linear, Flatten, MaxPool2d, etc.):
      • skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/
      • skainet-lang/skainet-lang-core/src/commonMain/kotlin/sk/ainet/lang/nn/dsl/NetworkBuilder.kt

Limitations and roadmap

  • Today:
    • There is no Concatenate layer node in the DSL; concatenate via ops.concat.
    • The sequential DSL is single-stream; multi-input branching needs either external wiring (solution 1) or the Functional wrapper (solution 2).
  • Near-term improvements (proposed):
    • Introduce a minimal MultiInputModule + DAG-aware DSL with inputs {}, branch {}, concat {} and output().
    • Add unit tests for multi-input DSL graphs and concat shape validations.
    • Optional: Multi-output heads and better error messages for missing inputs.

TL;DR

  • You can implement multimodal fusion today either by composing two DSL modules and concatenating at the call site, or by using the minimal Functional API to define the whole computation in one place.
  • A future DSL enhancement can bring first-class branching, named inputs, and concatenate into a single definition block for an even cleaner model declaration.
