From 69f3f1f061caa4cf12bb3e6d0bcec9a6dafb3b92 Mon Sep 17 00:00:00 2001 From: Adnan AlSinan Date: Thu, 19 Feb 2026 18:16:41 +0000 Subject: [PATCH 1/2] Update audiogen export script Signed-off-by: Adnan AlSinan --- .../audiogen/install_requirements.sh | 24 +++++++++---------- .../scripts/export_dit_autoencoder.py | 22 +++++++++-------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/kleidiai-examples/audiogen/install_requirements.sh b/kleidiai-examples/audiogen/install_requirements.sh index b7986d8..de39725 100644 --- a/kleidiai-examples/audiogen/install_requirements.sh +++ b/kleidiai-examples/audiogen/install_requirements.sh @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # @@ -9,31 +9,31 @@ # Install individual packages echo "Installing required packages for the Audiogen module..." -# ai-edge-torch -pip install ai-edge-torch==0.4.0 \ - "tf-nightly>=2.19.0.dev20250208" \ - "ai-edge-litert-nightly>=1.1.2.dev20250305" \ - "ai-edge-quantizer-nightly>=0.0.1.dev20250208" +# LiteRT torch +pip install litert-torch==0.8.0 \ + "ai-edge-litert==2.1.2" \ + "ai-edge-quantizer==0.4.2" # Stable audio tools pip install "stable_audio_tools==0.0.19" # Working out dependency issues, this combination of packages has been tested on different systems (Linux and MacOS). 
-pip install --no-deps "torch==2.6.0" \ - "torchaudio==2.6.0" \ - "torchvision==0.21.0" \ - "protobuf==5.29.4" \ +pip install --no-deps "torch==2.9.0" \ + "torchaudio==2.9.0" \ + "torchvision==0.24.0" \ + "protobuf==5.29.6" \ "numpy==1.26.4" \ -# Packages to convert via onnx +# Packages to convert via onnx pip install --no-deps "onnx==1.18.0" \ "onnxsim==0.4.36" \ + "onnx-ir==0.1.16" \ "onnx2tf==1.27.10" \ + "onnxscript==0.6.2" \ "tensorflow==2.19.0" \ "tf_keras==2.19.0" \ "onnx-graphsurgeon==0.5.8" \ - "ai_edge_litert" \ "sng4onnx==1.0.4" echo "Finished installing required packages for AudioGen submodules conversion." diff --git a/kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py b/kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py index 639363d..4be0e8e 100644 --- a/kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py +++ b/kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py @@ -1,26 +1,28 @@ # -# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +# SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates # # SPDX-License-Identifier: Apache-2.0 # +# Disable GPU to avoid any issues during export +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "" + import argparse import json import logging -import os -import ai_edge_torch +import litert_torch import torch from einops import rearrange -from ai_edge_torch.generative.quantize import quant_recipe, quant_recipe_utils -from ai_edge_torch.quantize import quant_config +from litert_torch.generative.quantize import quant_recipe, quant_recipe_utils +from litert_torch.quantize import quant_config from utils_load_model import load_model import stable_audio_tools -os.environ["CUDA_VISIBLE_DEVICES"] = "" torch.manual_seed(0) DEVICE = torch.device("cpu") @@ -157,7 +159,7 @@ def export_audiogen(args) -> None: # Create the dynamic weights int8 quantization config quant_config_audiogen_int8 = quant_config.QuantConfig( 
generative_recipe=quant_recipe.GenerativeQuantRecipe( - default=quant_recipe_utils.create_layer_quant_int8_dynamic(), + default=quant_recipe_utils.create_layer_quant_dynamic(), ) ) @@ -178,7 +180,7 @@ def rotary_emb_const(_): dit_model.model.transformer.rotary_pos_emb.forward_from_seq_len = rotary_emb_const # Export the DiT to LiteRT format - edge_model = ai_edge_torch.convert( + edge_model = litert_torch.convert( dit_model, sample_args=None, sample_kwargs=dit_model_example_input, quant_config=quant_config_audiogen_int8 ) edge_model.export("./dit_model.tflite") @@ -192,7 +194,7 @@ def rotary_emb_const(_): autoencoder_decoder_example_input = get_autoencoder_decoder_example_input(dtype) # Export the Encoder part of the AutoEncoder to LiteRT format - edge_model = ai_edge_torch.convert( + edge_model = litert_torch.convert( autoencoder_decoder, autoencoder_decoder_example_input, ) @@ -209,7 +211,7 @@ def rotary_emb_const(_): autoencoder_encoder_example_input = get_autoencoder_encoder_example_input(dtype) # Export the AutoEncoder to LiteRT format - edge_model = ai_edge_torch.convert( + edge_model = litert_torch.convert( autoencoder_encoder, autoencoder_encoder_example_input, ) From 532aa58b9be6e19e5f7dcb3353e1e9e814310f43 Mon Sep 17 00:00:00 2001 From: Adnan AlSinan Date: Wed, 13 May 2026 16:43:21 +0100 Subject: [PATCH 2/2] Update READMEs and conversion script Signed-off-by: Adnan AlSinan --- kleidiai-examples/audiogen/README.md | 6 +- .../audiogen/install_requirements.sh | 28 +- kleidiai-examples/audiogen/scripts/README.md | 139 ++----- .../audiogen/scripts/export_conditioners.py | 289 -------------- .../scripts/export_dit_autoencoder.py | 246 ------------ .../audiogen/scripts/export_sao.py | 181 +++++++++ kleidiai-examples/audiogen/scripts/model.py | 368 ++++++++++++++++++ .../audiogen/scripts/utils_load_model.py | 97 ----- 8 files changed, 581 insertions(+), 773 deletions(-) delete mode 100644 kleidiai-examples/audiogen/scripts/export_conditioners.py delete mode 
100644 kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py create mode 100644 kleidiai-examples/audiogen/scripts/export_sao.py create mode 100644 kleidiai-examples/audiogen/scripts/model.py delete mode 100644 kleidiai-examples/audiogen/scripts/utils_load_model.py diff --git a/kleidiai-examples/audiogen/README.md b/kleidiai-examples/audiogen/README.md index 212baf4..66e10ca 100644 --- a/kleidiai-examples/audiogen/README.md +++ b/kleidiai-examples/audiogen/README.md @@ -1,5 +1,5 @@ @@ -8,8 +8,8 @@ Welcome to the home of audio generation on Arm® CPUs, featuring Stable Audio Open Small with Arm® KleidiAI™. This project provides everything you need to: -- Convert models to LiteRT-compatible formats -- Run these models on Arm® CPUs using the LiteRT runtime, with support from XNNPack and Arm® KleidiAI™ +- Convert models to LiteRT formats using LiteRT Torch. +- Run these models on Arm® CPUs using the LiteRT runtime, with support from XNNPACK and Arm® KleidiAI™. ## Prerequisites diff --git a/kleidiai-examples/audiogen/install_requirements.sh b/kleidiai-examples/audiogen/install_requirements.sh index de39725..7d2189d 100644 --- a/kleidiai-examples/audiogen/install_requirements.sh +++ b/kleidiai-examples/audiogen/install_requirements.sh @@ -9,33 +9,15 @@ # Install individual packages echo "Installing required packages for the Audiogen module..." -# LiteRT torch -pip install litert-torch==0.8.0 \ - "ai-edge-litert==2.1.2" \ - "ai-edge-quantizer==0.4.2" - # Stable audio tools pip install "stable_audio_tools==0.0.19" +# LiteRT Torch +pip install "litert-torch==0.9.0" -# Working out dependency issues, this combination of packages has been tested on different systems (Linux and MacOS). 
-pip install --no-deps "torch==2.9.0" \ - "torchaudio==2.9.0" \ - "torchvision==0.24.0" \ - "protobuf==5.29.6" \ - "numpy==1.26.4" \ - -# Packages to convert via onnx -pip install --no-deps "onnx==1.18.0" \ - "onnxsim==0.4.36" \ - "onnx-ir==0.1.16" \ - "onnx2tf==1.27.10" \ - "onnxscript==0.6.2" \ - "tensorflow==2.19.0" \ - "tf_keras==2.19.0" \ - "onnx-graphsurgeon==0.5.8" \ - "sng4onnx==1.0.4" +# stable_audio_tools depends on numpy 1.26.4; pin this exact version, otherwise the export fails. +pip install --no-deps "numpy==1.26.4" echo "Finished installing required packages for AudioGen submodules conversion." echo "To start converting the Conditioners, DiT and Autoencoder modules conversion, use the following command:" -echo "python ./scripts/export_{MODEL-T0-CONVERT}.py" +echo "python ./scripts/export_sao.py" diff --git a/kleidiai-examples/audiogen/scripts/README.md b/kleidiai-examples/audiogen/scripts/README.md index f56990a..8a6be27 100644 --- a/kleidiai-examples/audiogen/scripts/README.md +++ b/kleidiai-examples/audiogen/scripts/README.md @@ -1,5 +1,5 @@ @@ -7,7 +7,7 @@ # Building and Running the Audio Generation Application on Arm® CPUs with the Stable Audio Open Small Model ## Goal -This guide will show you how to convert the Stable Audio Open Small Model to LiteRT-compatible form to run on Arm® CPUs with the LiteRT runtime. +This guide will show you how to convert the Stable Audio Open Small Model to LiteRT format to run on Arm® CPUs with the LiteRT runtime. ### Converting the Stable Audio Open Small Model to LiteRT format The Stable Audio Open Small Model is made of three submodules: @@ -15,11 +15,9 @@ The Stable Audio Open Small Model is made of three submodules: - Diffusion Transformer (DiT) - AutoEncoder. -You will explore two different conversion routes, to convert the submodules to LiteRT format. +You will use LiteRT Torch to convert these submodules to LiteRT format. -1. __ONNX → LiteRT__ using the [onnx2tf](https://github.com/PINTO0309/onnx2tf) tool. 
This is the traditional two-step approach (PyTorchONNXLiteRT). You will use it to convert the Conditioners submodule. - -2. __PyTorch → LiteRT__ using the [Google AI Edge Torch](https://developers.googleblog.com/en/ai-edge-torch-high-performance-inference-of-pytorch-models-on-mobile-devices/) tool. This method, currently under active development, aims to simplify the conversion by performing it in a single step. You will use this tool to convert the DiT and AutoEncoder submodules. +__PyTorch → LiteRT__ using the [LiteRT Torch](https://github.com/google-ai-edge/litert-torch) tool. This tool aims to simplify the conversion and the quantization of torch models to LiteRT, for easy deployment on edge devices. ### Create a virtual environment and install dependencies. @@ -41,136 +39,47 @@ bash install_requirements.sh Option B ```bash # Option B (with .venv activated) -# Packages for the ai-edge-torch tool -pip install ai-edge-torch==0.4.0 \ - "tf-nightly>=2.19.0.dev20250208" \ - "ai-edge-litert-nightly>=1.1.2.dev20250305" \ - "ai-edge-quantizer-nightly>=0.0.1.dev20250208" -# Stable-Audio Tools +# Stable audio tools pip install "stable_audio_tools==0.0.19" -# Working out dependency issues, this combination of packages has been tested on different systems (Linux® and macOS®). 
-pip install --no-deps "torch==2.6.0" \ - "torchaudio==2.6.0" \ - "torchvision==0.21.0" \ - "protobuf==5.29.4" \ - "numpy==1.26.4" \ - -# Packages to convert using ONNX -pip install --no-deps onnx \ - onnxsim \ - onnx2tf \ - tensorflow \ - tf_keras \ - onnx_graphsurgeon \ - ai_edge_litert \ - sng4onnx -``` +# Install LiteRT Torch +pip install "litert-torch==0.9.0" -> [!NOTE] -> -> If you are using GPU on your machine, you might faced the following error: -> ```bash -> Traceback (most recent call last): -> File "/home//Workspace/tflite/env3_10/lib/python3.10/site-packages/torch/_inductor/runtime/hints.py", line 46, in -> from triton.backends.compiler import AttrsDescriptor -> ImportError: cannot import name 'AttrsDescriptor' from 'triton.backends.compiler' (/home//Workspace/tflite/env3_10/lib/> python3.10/site-packages/triton/backends/compiler.py) -> -> During handling of the above exception, another exception occurred: -> -> Traceback (most recent call last): -> File "/home//Workspace/tflite/audiogen/./scripts/export_dit_autoencoder.py", line 6, in -> import ai_edge_torch -> File "/home//Workspace/tflite/env3_10/lib/python3.10/site-packages/ai_edge_torch/__init__.py", line 16, in -> from ai_edge_torch._convert.converter import convert -> File "/home//Workspace/tflite/env3_10/lib/python3.10/site-packages/ai_edge_torch/_convert/converter.py", line 21, in > -> from ai_edge_torch._convert import conversion -> File "/home//Workspace/tflite/env3_10/lib/python3.10/site-packages/ai_edge_torch/_convert/conversion.py", line 23, in > -> from ai_edge_torch._convert import fx_passes -> File "/home//Workspace/tflite/env3_10/lib/python3.10/site-packages/ai_edge_torch/_convert/fx_passes/__init__.py", line 21, > in -> from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass -> . -> . -> . 
-> ImportError: cannot import name 'AttrsDescriptor' from 'triton.compiler.compiler' (/home//Workspace/tflite/env3_10/lib/> python3.10/site-packages/triton/compiler/compiler.py) -> ``` -> Please use triton 3.2.0 as the following: -> ```bash -> pip install triton==3.2.0 -> ``` - - -### Convert Conditioners Submodule -The Conditioners submodule is based on the T5Encoder model. Convert it first to ONNX, then to LiteRT format. All details are implemented in [`scripts/export_conditioners.py`](./export_conditioners.py), which includes the following steps: - - 1. Load the Conditioners submodule from the Stable Audio Open Small Model configuration and checkpoint. - 2. Export the Conditioners submodule to ONNX via `torch.onnx.export()`. - 3. Convert the resulting `.onnx` file to LiteRT using `onnx2tf`. - -The two conversion steps (PyTorch -> ONNX and ONNX -> LiteRT) are defined as follows: - - PyTorch -> ONNX -```python -# Export to ONNX -torch.onnx.export( - model, - example_inputs, - output_path, - input_names=[], #Model inputs, a list of input tensors - output_names=[], #Model outputs, a list of output tensors - opset_version=15, - ) -``` +# Install numpy with this version +pip install --no-deps "numpy==1.26.4" - ONNX -> LiteRT -```bash -# Conversion to LiteRT format -onnx2tf -i "input_onnx_model_path" -o "output_folder_path" -``` -_or within a Python script_: -```python -import subprocess - -onnx2tf_command = [ - "onnx2tf", - "-i", str(input_onnx_model_path), - "-o", str(output_folder_path), -] -# Call the command line tool -subprocess.run(onnx2tf_command, check=True) ``` -Converting an `.onnx` model to `.tflite`, creates a folder containing models with different precisions (e.g., float16, float32). You will be using the float32.tflite model for on-device inference. 
-To run the [`scripts/export_conditioners.py`](./export_conditioners.py) script, use the following command (ensure your .venv is still active): - -```bash -python3 ./scripts/export_conditioners.py --model_config "$WORKSPACE/model_config.json" --ckpt_path "$WORKSPACE/model.ckpt" -``` - -### Convert DiT and AutoEncoder Submodules -To convert the DiT and AutoEncoder submodules, we use the [Generative API](https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative/) provided in by the `ai-edge-torch` tools. This API supports exporting a PyTorch model directly to LiteRT following three mains steps; model re-authoring, quantization, and finally conversion. +### Exporting the models +To convert the models, we use the [Generative API](https://github.com/google-ai-edge/litert-torch/tree/main/litert_torch/generative) provided by the `litert_torch` tools. This API supports exporting a PyTorch model directly to LiteRT following three main steps: model re-authoring, quantization, and finally conversion. Here is a code snippet illustrating how the API works in practice. ```python -import ai_edge_torch -from ai_edge_torch.generative.quantize import quant_recipe +import litert_torch +from litert_torch.quantize import quant_config +from litert_torch.generative.quantize import quant_recipe, quant_recipe_utils + # Specify the quantization format -quant_config = quant_recipes.full_int8_dynamic_recipe() +quant_config_int8 = quant_config.QuantConfig( + generative_recipe=quant_recipe.GenerativeQuantRecipe( + default=quant_recipe_utils.create_layer_quant_dynamic(), + ) +) # Initiate the conversion -edge_model = ai_edge_torch.convert( +edge_model = litert_torch.convert( - model, example_inputs, quant_config=quant_config + model, example_inputs, quant_config=quant_config_int8 ) ``` -Notes on the arguments for `ai_edge_torch.convert()`: +Notes on the arguments for `litert_torch.convert()`: - __model__: The PyTorch model to be converted. 
This should be the pre-trained model loaded from the `.config` and `.ckpt` files, and set to evaluation mode (model.eval()). - __example_inputs__: A tuple of torch.Tensor objects. These are dummy input tensors that match the expected shape and type of your model's forward pass arguments. For models with multiple inputs, provide them as a tuple in the correct order. -To convert the DiT and AutoEncoder submodules, run the [`export_dit_autoencoder.py`](./export_dit_autoencoder.py) script using the following command (ensure your .venv is still active): +To convert the models, run the [`export_sao.py`](./export_sao.py) script using the following command (ensure your .venv is still active): ```bash -python3 ./scripts/export_dit_autoencoder.py --model_config "$WORKSPACE/model_config.json" --ckpt_path "$WORKSPACE/model.ckpt" +python3 ./scripts/export_sao.py --model_config "$WORKSPACE/model_config.json" --ckpt_path "$WORKSPACE/model.ckpt" ``` The three LiteRT format models will be required to run the audiogen application on Android™ device. diff --git a/kleidiai-examples/audiogen/scripts/export_conditioners.py b/kleidiai-examples/audiogen/scripts/export_conditioners.py deleted file mode 100644 index f579a08..0000000 --- a/kleidiai-examples/audiogen/scripts/export_conditioners.py +++ /dev/null @@ -1,289 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -# -# SPDX-License-Identifier: Apache-2.0 -# - -# Exporting Conditioners to .onnx and LiteRT format (.tflite) -import argparse -import json -import logging -import os -import subprocess -from typing import Any - -import torch -from utils_load_model import load_model - -logging.basicConfig(level=logging.INFO) - -os.environ["CUDA_VISIBLE_DEVICES"] = "" -DEVICE = torch.device("cpu") - - -## ----------------- Utility Functions ------------------- -def get_conditioners(model: str): - """Load the conditioners module from the AudioGen model. - Args: - model (str): The AudioGen model. 
- Returns: - sao_t5_cond: The T5 encoder. - sao_seconds_total_cond: The seconds_total conditioner. - """ - cond_model = model.conditioner - t5_cond = cond_model.conditioners["prompt"] - seconds_total_cond = cond_model.conditioners["seconds_total"] - - return t5_cond, seconds_total_cond - - -def get_conditioners_example_input(seconds_total: float, seq_length: int): - """Provide example input tensors for the AudioGen Conditioners submodule. - Args: - seconds_total (float): The total seconds for the audio. - seq_length (int): The sequence length for the T5 encoder. - Returns: - input_ids (torch.Tensor): The input IDs tensor for the T5 encoder. - attention_mask (torch.Tensor): The attention mask tensor for the T5 encoder. - seconds_total (torch.Tensor): The seconds_total tensor. - """ - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained("t5-base") - encoded = tokenizer( - text="birds singing in the morning", - truncation=True, - max_length=seq_length, - padding="max_length", - return_tensors="pt", - ) - - # Create the input_ids and attention_mask tensors for sao conditioners - input_ids = encoded["input_ids"] - attention_mask = encoded["attention_mask"] - - # Create the seconds_total tensor - seconds_total = torch.tensor([seconds_total], dtype=torch.float) - - return ( - input_ids, - attention_mask, - seconds_total, - ) - - -def get_conditioners_module(model): - """ - Wrap both the T5 encoder and seconds_total conditioner in a single module. - """ - # Load the SAO conditioners - sao_t5_cond, sao_seconds_total_cond = get_conditioners(model) - - # Return the conditioners module - return ConditionersModule( - sao_t5_cond=sao_t5_cond, - sao_seconds_total_cond=sao_seconds_total_cond, - ) - - -def convert_conditioners_to_onnx(model, example_inputs, output_path): - """Convert the Pytorch Conditioners model to ONNX format. - Args: - model (torch.nn.Module): The Pytorch model to convert. 
- example_inputs (tuple): A tuple of example input tensors for the model. - output_path (str): The path to save the converted ONNX model. - Returns: - str: The path to the converted ONNX model. - """ - # Export the model to ONNX format - torch.onnx.export( - model, - example_inputs, - output_path, - input_names=["input_ids", "attention_mask", "seconds_total"], - output_names=["cross_attention_input", "cross_attention_masks", "global_cond"], - opset_version=15, - ) - print(f"Model exported to {output_path}") - return output_path - - -## ----------------- Wrapper Classes ------------------- -class ExportableNumberConditioner(torch.nn.Module): - """NumberConditioner Module. Take a list of floats, - normalizes them for a given range, and returns a list of embeddings. - """ - - def __init__( - self, - numberConditioner, - ): - super(ExportableNumberConditioner, self).__init__() - - self.min_val = numberConditioner.min_val - self.max_val = numberConditioner.max_val - - self.embedder = numberConditioner.embedder - - def forward(self, floats: torch.tensor) -> Any: - floats = floats.clamp(self.min_val, self.max_val) - - normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) - - # Cast floats to same type as embedder - embedder_dtype = next(self.embedder.parameters()).dtype - normalized_floats = normalized_floats.to(embedder_dtype) - - float_embeds = self.embedder(normalized_floats).unsqueeze(1) - - return float_embeds, torch.ones(float_embeds.shape[0], 1) - -class ConditionersModule(torch.nn.Module): - """Conditioners Module. Takes the T5 encoder and seconds_total conditioner, - and returns the cross-attention inputs and global conditioning inputs. 
- """ - - def __init__( - self, - sao_t5_cond: torch.nn.Module, - sao_seconds_total_cond: torch.nn.Module, - ): - super().__init__() - self.sao_t5 = sao_t5_cond - self.sao_seconds_total_cond = ExportableNumberConditioner( - sao_seconds_total_cond - ) - - # Use float - self.sao_t5 = ( - self.sao_t5.to("cpu").to(dtype=torch.float).eval().requires_grad_(False) - ) - self.sao_seconds_total_cond = self.sao_seconds_total_cond.to(dtype=torch.float) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - seconds_total: torch.Tensor, - ): - # Get the projections and conditioner results - with torch.no_grad(): - t5_embeddings = self.sao_t5.model( - input_ids=input_ids, attention_mask=attention_mask - )["last_hidden_state"] - # Resize the embeddings and attention mask to 64 to match DiT model - t5_embeddings = t5_embeddings[:, :64, :] - attention_mask = attention_mask[:, :64] - # Get the T5 projections - t5_proj = self.sao_t5.proj_out(t5_embeddings.float()) - t5_proj = t5_proj * attention_mask.unsqueeze(-1).float() - t5_mask = attention_mask.float() - - # Get seconds_total conditioner results - seconds_total_embedding, seconds_total_mask = self.sao_seconds_total_cond( - seconds_total - ) - - # Concatenate all cross-attention inputs (t5_embedding, seconds_total) over the sequence dimension - # Assumes that the cross-attention inputs are of shape (batch, seq, channels) - cross_attention_input = torch.cat( - [ - t5_proj, - seconds_total_embedding, - ], - dim=1, - ) - cross_attention_masks = torch.cat( - [ - t5_mask, - seconds_total_mask, - ], - dim=1, - ) - - # Concatenate all global conditioning inputs (seconds_start, seconds_total) over the channel dimension - # Assumes that the global conditioning inputs are of shape (batch, channels) - global_cond = torch.cat( - [ - seconds_total_embedding - ], - dim=-1, - ) - global_cond = global_cond.squeeze(1) - - return cross_attention_input, cross_attention_masks, global_cond - - -## ----------------- 
Exporting Conditioners to LiteRT format ------------------- -def export_conditioners(args) -> None: - """Export the conditioners of the AudioGen model to LiteRT format.""" - - model_config = None - dtype = torch.float32 - - # Load the model configuration - logging.info("Loading the AudioGen Checkpoint...") - with open(args.model_config, encoding="utf-8") as f: - model_config = json.load(f) - model, model_config = load_model( - model_config, - args.ckpt_path, - pretrained_name=None, - device=DEVICE, - ) - - # Load the conditioners and the t5 model - conditioners = get_conditioners_module(model=model) - conditioners = conditioners.to(dtype).eval().requires_grad_(False) - conditioners_example_input = get_conditioners_example_input( - seq_length=128, seconds_total=10.0 - ) - # Export the conditioners first to ONNX - logging.info("Starting Conditioners export to ONNX...\n") - onnx_model_path = convert_conditioners_to_onnx( - conditioners, - conditioners_example_input, - output_path="./conditioners.onnx", - ) - logging.info( - "Conditioners in ONNX format has been saved to %s/conditioners.onnx", - onnx_model_path, - ) - logging.info("Starting ONNX to LiteRT conversion...\n") - # Convert the ONNX model to LiteRT format - Use command line for faster conversion - onnx2tf_command = [ - "onnx2tf", - "-i", - str(onnx_model_path), - "-o", - "./conditioners_tflite", - ] - # Call the command line tool - subprocess.run(onnx2tf_command, check=True) - logging.info( - "Conditioners in LiteRT format has been saved to %s/conditioners_tflite", - ) - - -def main(): - """Main function to export the AudioGen Conditioners model to onnx and then LiteRT format.""" - parser = argparse.ArgumentParser() - # Export the model to ONNX and LiteRT format - parser.add_argument( - "-m", - "--model_config", - type=str, - help="Path to the model configuration file.", - required=True, - ) - parser.add_argument( - "--ckpt_path", - type=str, - help="Path to the model checkpoint file.", - required=True, - ) 
- export_conditioners(parser.parse_args()) - - -if __name__ == "__main__": - main() diff --git a/kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py b/kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py deleted file mode 100644 index 4be0e8e..0000000 --- a/kleidiai-examples/audiogen/scripts/export_dit_autoencoder.py +++ /dev/null @@ -1,246 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates -# -# SPDX-License-Identifier: Apache-2.0 -# - -# Disable GPU to avoid any issues during export -import os -os.environ["CUDA_VISIBLE_DEVICES"] = "" - -import argparse -import json -import logging - -import litert_torch -import torch - -from einops import rearrange - -from litert_torch.generative.quantize import quant_recipe, quant_recipe_utils -from litert_torch.quantize import quant_config -from utils_load_model import load_model - -import stable_audio_tools - -torch.manual_seed(0) -DEVICE = torch.device("cpu") - -logging.basicConfig(level=logging.INFO) - -## ----------------- Utility Functions DiT ------------------- -def get_dit_example_input_mapping(dtype=torch.float): - """Provide example input tensors for the DiT model as a dictionary. - Args: - dtype (torch.dtype): The data type for the input tensors. - Returns: - dict: A dictionary containing the example input tensors for the DiT model. - x (torch.Tensor): The input tensor for the DiT model. - t (torch.Tensor): The time tensor for the DiT model. - cross_attn_cond (torch.Tensor): The cross attention conditioning tensor for the DiT model. Output of the Conditioner T5 Encoder. - global_cond (torch.Tensor): The global conditioning tensor for the DiT model. Output of the Conditioner Number Encoder. 
- """ - return { - "x": torch.rand(size=(1, 64, 256), dtype=dtype, requires_grad=False), # x - "t": torch.tensor([0.154], dtype=dtype, requires_grad=False), # t - "cross_attn_cond": torch.rand( - size=(1, 65, 768), dtype=dtype, requires_grad=False - ), # cross_attn_cond - "global_cond": torch.rand(size=(1, 768), dtype=dtype, requires_grad=False), # global_cond - } - - -## ----------------- Utility Functions AutoEncoder ------------------- -def get_autoencoder_decoder_module(model): - """Get the AutoEncoder module from the AudioGen model.""" - return AutoEncoderDecoderModule(model.pretransform) - -def get_autoencoder_encoder_module(model): - """Get the AutoEncoder module from the AudioGen model.""" - return AutoEncoderEncoderModule(model.pretransform) - -def get_autoencoder_decoder_example_input(dtype=torch.float): - """Get example input for the AutoEncoder module.""" - return (torch.rand((1, 64, 256), dtype=dtype),) - -def get_autoencoder_encoder_example_input(dtype=torch.float): - """Get example input for the AutoEncoder module.""" - return (torch.rand((1, 2, 524288), dtype=dtype),) - - -class AutoEncoderDecoderModule(torch.nn.Module): - """Wrap the AutoEncoder Module. Takes the AutoEncoder and returns the audio. - Args: - autoencoder (torch.nn.Module): The AutoEncoder module. - Returns: - audio (torch.Tensor): The decoded audio tensor. 
- """ - - def __init__(self, autoencoder): - super(AutoEncoderDecoderModule, self).__init__() - self.autoencoder = autoencoder - - # Use float - self.autoencoder = ( - self.autoencoder.to(dtype=torch.float).eval().requires_grad_(False) - ) - - def forward(self, sampled: torch.Tensor): - dtype = torch.float - sampled_uncompressed = self.autoencoder.decode(sampled.to(dtype)) - - audio = rearrange(sampled_uncompressed, "b d n -> d (b n)") - - return audio - -def vae_sample_updated(mean, scale): - stdev = torch.nn.functional.softplus(scale) + 1e-4 - var = stdev * stdev - logvar = torch.log(var) - - # "randn_like" was causing failures while exporting the model: - # latents = torch.randn_like(mean) * stdev + mean - rand = torch.randn(mean.size()) - latents = rand * stdev + mean - - kl = (mean * mean + var - logvar - 1).sum(1).mean() - - return latents, kl - -class AutoEncoderEncoderModule(torch.nn.Module): - """Wrap the AutoEncoder Module. Takes the AutoEncoder and returns the audio. - Args: - autoencoder (torch.nn.Module): The AutoEncoder module. - Returns: - audio (torch.Tensor): The decoded audio tensor. 
- """ - - def __init__(self, autoencoder): - super(AutoEncoderEncoderModule, self).__init__() - self.autoencoder = autoencoder - - # Use float - self.autoencoder = ( - self.autoencoder.to(dtype=torch.float).eval().requires_grad_(False) - ) - - stable_audio_tools.models.bottleneck.vae_sample = vae_sample_updated - - def forward(self, sampled: torch.Tensor): - dtype = torch.float - sample_compressed = self.autoencoder.encode(sampled.to(dtype)) - - return sample_compressed - -## ----------------- Exporting DiT and AutoEncoder to LiteRT format ------------------- -def export_audiogen(args) -> None: - """Export the Dit and AutoEncoder of the AudioGen model to LiteRT format.""" - - model_config = None - dtype = torch.float32 - - # Load the model Configuration - logging.info("Loading the AudioGen Checkpoint...") - with open(args.model_config, encoding="utf-8") as f: - model_config = json.load(f) - model, model_config = load_model( - model_config, - args.ckpt_path, - pretrained_name=None, - device=DEVICE, - ) - - logging.info( - "Exporting the model, ckpt: %s, with config %s.", - args.ckpt_path, - args.model_config, - ) - - # Create the dynamic weights int8 quantization config - quant_config_audiogen_int8 = quant_config.QuantConfig( - generative_recipe=quant_recipe.GenerativeQuantRecipe( - default=quant_recipe_utils.create_layer_quant_dynamic(), - ) - ) - - ## --------- DiT Model --------- - # Load the diffusion transformer model (DiT) - logging.info("Starting DiT Model conversion to LiteRT format...\n") - dit_model = model.model - dit_model = dit_model.to(dtype).eval().requires_grad_(False) - dit_model_example_input = get_dit_example_input_mapping(dtype) - logging.info("Exporting the DiT model...") - - # # Workaround for some issue in LiteRT that occurs at runtime - rotary_pos_emb_res = ( - dit_model.model.transformer.rotary_pos_emb.forward_from_seq_len(257) - ) - def rotary_emb_const(_): - return rotary_pos_emb_res - 
dit_model.model.transformer.rotary_pos_emb.forward_from_seq_len = rotary_emb_const - - # Export the DiT to LiteRT format - edge_model = litert_torch.convert( - dit_model, sample_args=None, sample_kwargs=dit_model_example_input, quant_config=quant_config_audiogen_int8 - ) - edge_model.export("./dit_model.tflite") - logging.info("DiT model has been saved to %s/dit_model.tflite") - - ## --------- AutoEncoder Decoder Model --------- - # Load the Encoder part of the AutoEncoder - logging.info("Starting AutoEncoder Decoder Model conversion to LiteRT format...\n") - autoencoder_decoder = get_autoencoder_decoder_module(model) - autoencoder_decoder = autoencoder_decoder.to(dtype).eval().requires_grad_(False) - autoencoder_decoder_example_input = get_autoencoder_decoder_example_input(dtype) - - # Export the Encoder part of the AutoEncoder to LiteRT format - edge_model = litert_torch.convert( - autoencoder_decoder, - autoencoder_decoder_example_input, - ) - edge_model.export("./autoencoder_model.tflite") - logging.info( - "AutoEncoder model has been saved to %s/autoencoder_model.tflite", - ) - - ## --------- AutoEncoder Encoder Model --------- - # Load the Encoder part of the AutoEncoder - logging.info("Starting AutoEncoder Encoder Model conversion to LiteRT format...\n") - autoencoder_encoder = get_autoencoder_encoder_module(model) - autoencoder_encoder = autoencoder_encoder.to(dtype).eval().requires_grad_(False) - autoencoder_encoder_example_input = get_autoencoder_encoder_example_input(dtype) - - # Export the AutoEncoder to LiteRT format - edge_model = litert_torch.convert( - autoencoder_encoder, - autoencoder_encoder_example_input, - ) - edge_model.export("./autoencoder_encoder_model.tflite") - logging.info( - "AutoEncoder model has been saved to %s/autoencoder_encoder_model.tflite", - ) - - -def main(): - """Main function to export the AudioGen model to LiteRT format.""" - parser = argparse.ArgumentParser() - - parser.add_argument( - "-m", - "--model_config", - type=str, 
- help="Path to model config", - required=True - ) - parser.add_argument( - "-p", - "--ckpt_path", - type=str, - help="Path to model checkpoint", - required=True - ) - export_audiogen(parser.parse_args()) - - -if __name__ == "__main__": - main() diff --git a/kleidiai-examples/audiogen/scripts/export_sao.py b/kleidiai-examples/audiogen/scripts/export_sao.py new file mode 100644 index 0000000..1a1e58f --- /dev/null +++ b/kleidiai-examples/audiogen/scripts/export_sao.py @@ -0,0 +1,181 @@ +# +# SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +import argparse +import json +import logging +import os + +import torch + +from model import (get_dit_module, load_model, + get_autoencoder_decoder_module, + get_autoencoder_decoder_example_input, + get_autoencoder_encoder_module, + get_autoencoder_encoder_example_input, + get_conditioners_module, + get_conditioners_example_input, + get_dit_example_input_mapping) + +from stable_audio_tools.models.utils import remove_weight_norm_from_model + +import litert_torch + +from litert_torch.generative.quantize import quant_recipe, quant_recipe_utils +from litert_torch.quantize import quant_config + +logging.basicConfig(level=logging.INFO) + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + +T5_SEQ_LENGTH = 128 + +def export_conditioners(model, output_path) -> None: + + with torch.no_grad(): + conditioners_model = get_conditioners_module(model) + conditioners_model = conditioners_model.eval().requires_grad_(False) + conditioners_example_input = get_conditioners_example_input(seq_length=T5_SEQ_LENGTH, seconds_total=10.0) + + edge_model = litert_torch.convert( + conditioners_model, sample_args=conditioners_example_input, sample_kwargs=None + ) + + edge_model.export(os.path.join(output_path, "conditioners_float32.tflite")) + logging.info("Conditioners model has been saved to %s", os.path.abspath(os.path.join(output_path, "conditioners_float32.tflite"))) + +def export_dit(model, 
output_path, dtype = torch.float) -> None: + + logging.info("Starting DiT Model conversion to LiteRT format...\n") + + with torch.no_grad(): + dit_model = get_dit_module(model=model) + dit_model = dit_model.to(dtype).eval().requires_grad_(False) + dit_model_example_input = get_dit_example_input_mapping(dtype) + + # Create the dynamic weights int8 quantization config + quant_config_audiogen_int8 = quant_config.QuantConfig( + generative_recipe=quant_recipe.GenerativeQuantRecipe( + default=quant_recipe_utils.create_layer_quant_dynamic(), + ) + ) + + # Workaround for some issue in LiteRT that occurs at runtime + rotary_pos_emb_res = ( + dit_model.model.transformer.rotary_pos_emb.forward_from_seq_len(257) + ) + def rotary_emb_const(_): + return rotary_pos_emb_res + dit_model.model.transformer.rotary_pos_emb.forward_from_seq_len = rotary_emb_const + + # Export the DiT to LiteRT format + edge_model = litert_torch.convert( + dit_model, sample_args=None, sample_kwargs=dit_model_example_input, quant_config=quant_config_audiogen_int8 + ) + + edge_model.export(os.path.join(output_path, "dit_model.tflite")) + logging.info("DiT model has been saved to %s", os.path.abspath(os.path.join(output_path, "dit_model.tflite"))) + +def export_autoencoder(model, output_path, dtype = torch.float) -> None: + + logging.info("Starting AutoEncoder Decoder conversion...\n") + + with torch.no_grad(): + autoencoder_decoder_example_input = get_autoencoder_decoder_example_input(dtype=dtype) + # model.pretransform.model_half=False + model = model.to(dtype).eval().requires_grad_(False) + + autoencoder_decoder = get_autoencoder_decoder_module(model) + autoencoder_decoder = autoencoder_decoder.to(dtype).eval().requires_grad_(False) + + # Export the model to LiteRT format + edge_model = litert_torch.convert( + autoencoder_decoder, sample_args=autoencoder_decoder_example_input, + ) + edge_model.export(os.path.join(output_path, "autoencoder_model.tflite")) + logging.info("AutoEncoder Decoder model has been 
saved to %s", os.path.abspath(os.path.join(output_path, "autoencoder_model.tflite"))) + +def export_autoencoder_encoder(model, output_path, dtype = torch.float) -> None: + + logging.info("Starting AutoEncoder Encoder conversion...\n") + + with torch.no_grad(): + autoencoder_encoder_example_input = get_autoencoder_encoder_example_input(dtype=dtype) + # model.pretransform.model_half=False + model = model.to(dtype).eval().requires_grad_(False) + + autoencoder_encoder = get_autoencoder_encoder_module(model) + autoencoder_encoder = autoencoder_encoder.to(dtype).eval().requires_grad_(False) + + # Export the model to LiteRT format + edge_model = litert_torch.convert( + autoencoder_encoder, sample_args=autoencoder_encoder_example_input, + ) + edge_model.export(os.path.join(output_path, "autoencoder_encoder_model.tflite")) + logging.info("AutoEncoder Encoder model has been saved to %s", os.path.abspath(os.path.join(output_path, "autoencoder_encoder_model.tflite"))) + +def export(args) -> None: + + torch.manual_seed(0) + device = torch.device("cpu") + + # Load the model configuration + logging.info("Loading the AudioGen Checkpoint...") + with open(args.model_config, encoding="utf-8") as f: + model_config = json.load(f) + + # Load the model + model, model_config = load_model( + model_config = model_config, + model_ckpt_path = args.ckpt_path, + pretrained_name=None, + device=device, + ) + logging.info("Model is loaded...") + + # --------- Conditioners Model --------- + export_conditioners(model, args.output_path) + + # --------- DiT Model ---------------- + export_dit(model, args.output_path) + + # --------- AutoEncoder Model --------- + + # Removing weight norm from the model as it is causing issues during export + remove_weight_norm_from_model(model.pretransform) + + export_autoencoder(model, args.output_path) + export_autoencoder_encoder(model, args.output_path) + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model_config", + 
type=str, + help="Path to the model configuration file.", + required=True, + ) + parser.add_argument( + "--ckpt_path", + type=str, + help="Path to the model checkpoint file.", + required=True, + ) + + parser.add_argument( + "--output_path", + type=str, + help="Path to the output directory for the exported models.", + default=".", + required=False, + ) + + export(parser.parse_args()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/kleidiai-examples/audiogen/scripts/model.py b/kleidiai-examples/audiogen/scripts/model.py new file mode 100644 index 0000000..8f2dab5 --- /dev/null +++ b/kleidiai-examples/audiogen/scripts/model.py @@ -0,0 +1,368 @@ +# +# SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# + +import logging +from typing import Any, Dict, Optional, Tuple + +import torch +from einops import rearrange + +import stable_audio_tools +from stable_audio_tools.models.factory import create_model_from_config +from stable_audio_tools.models.pretrained import get_pretrained_model +from stable_audio_tools.models.utils import load_ckpt_state_dict +from stable_audio_tools.models.utils import copy_state_dict + +DEVICE = torch.device("cpu") + +logging.basicConfig(level=logging.INFO) + +import torch +import torch.nn as nn + +def force_t5_conditioner_float32(t5_cond: nn.Module) -> nn.Module: + """Force the wrapped T5 conditioner into float32, including hidden modules. + + Args: + t5_cond: The T5 conditioner module to be converted to float32. + Returns: + The T5 conditioner module with all parameters and submodules converted to float32. 
+ """ + t5_cond = t5_cond.to(torch.float32) + + if hasattr(t5_cond, "proj_out") and isinstance(t5_cond.proj_out, nn.Module): + t5_cond.proj_out = t5_cond.proj_out.to(torch.float32) + + hidden_model = getattr(t5_cond, "model", None) + if isinstance(hidden_model, nn.Module): + t5_cond.__dict__["model"] = ( + hidden_model.to(torch.float32).eval().requires_grad_(False) + ) + + return t5_cond + + +## Model loading +def load_model( + model_config: Optional[Dict[str, Any]] = None, + model_ckpt_path: Optional[str] = None, + pretrained_name: Optional[str] = None, + pretransform_ckpt_path: Optional[str] = None, + device: torch.device = DEVICE, +) -> Tuple[torch.nn.Module, Dict[str, Any]]: + """Load the AudioGen model and its configuration. + + Either a pretrained model (via `pretrained_name`) or a freshly constructed one + (via `model_config` + `model_ckpt_path`) will be loaded. + + Args: + model_config: Configuration dict for creating the model. + model_ckpt_path: Path to a model checkpoint file. + pretrained_name: Name of a model to load from the repo. + pretransform_ckpt_path: Optional path to a pretransform checkpoint. + device: Torch device to map the model to. + + Returns: + A tuple of (model, model_config), where `model` is in eval mode + and cast to float, and `model_config` contains sample_rate/size, etc. 
+ """ + + if pretrained_name is not None: + logging.info("Loading pretrained model: %s", pretrained_name) + model, model_config = get_pretrained_model(pretrained_name) + + elif model_config is not None: + if model_ckpt_path is None: + raise ValueError( + "model_ckpt_path must be provided when specifying model_config" + ) + logging.info("Creating model from config") + model = create_model_from_config(model_config) + + logging.info("Loading model checkpoint from: %s", model_ckpt_path) + + # Load checkpoint + copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) + logging.info("Done loading model checkpoint") + + if pretransform_ckpt_path is not None: + logging.info("Loading pretransform checkpoint from %r", pretransform_ckpt_path) + model.pretransform.load_state_dict( + load_ckpt_state_dict(pretransform_ckpt_path), strict=False + ) + logging.info("Done loading pretransform.") + + model.to(device).eval().requires_grad_(False) + model.pretransform.model_half = False + model = model.to(torch.float) + + return model, model_config + + +## Utility functions for conditioners +def get_conditioners(model): + """Load the conditioners module from Stable Audio Open Small model. + Args: + model: Stable Audio Open Small model. + Returns: + sao_t5_cond: The T5 encoder. + sao_seconds_total_cond: The seconds_total conditioner. + """ + cond_model = model.conditioner + t5_cond = force_t5_conditioner_float32(cond_model.conditioners["prompt"]) + seconds_total_cond = cond_model.conditioners["seconds_total"] + + return t5_cond, seconds_total_cond + +# Wrapper class for number conditioner +class ExportableNumberConditioner(torch.nn.Module): + """NumberConditioner Module. Take a list of floats, + normalizes them for a given range, and returns a list of embeddings. 
+ """ + + def __init__( + self, + numberConditioner, + ): + super(ExportableNumberConditioner, self).__init__() + + self.min_val = numberConditioner.min_val + self.max_val = numberConditioner.max_val + + self.embedder = numberConditioner.embedder + + def forward(self, floats: torch.tensor) -> Any: + floats = floats.clamp(self.min_val, self.max_val) + + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return float_embeds, torch.ones(float_embeds.shape[0], 1) + +class ConditionersModule(torch.nn.Module): + """Conditioners Module. Takes the T5 encoder and seconds_total conditioner, + and returns the cross-attention inputs and global conditioning inputs. + """ + + def __init__( + self, + sao_t5_cond: torch.nn.Module, + sao_seconds_total_cond: torch.nn.Module, + dtype: torch.dtype = torch.float + ): + super().__init__() + self.sao_t5 = sao_t5_cond + self.sao_seconds_total_cond = ExportableNumberConditioner( + sao_seconds_total_cond + ) + self.dtype = dtype + + # Use float + self.sao_t5 = force_t5_conditioner_float32(self.sao_t5.to("cpu")) + self.sao_t5 = self.sao_t5.to(dtype).eval().requires_grad_(False) + self.sao_seconds_total_cond = self.sao_seconds_total_cond.to(dtype=dtype) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + seconds_total: torch.Tensor, + ): + # Get the projections and conditioner results + with torch.no_grad(): + t5_embeddings = self.sao_t5.model( + input_ids=input_ids, attention_mask=attention_mask + )["last_hidden_state"] + # Resize the embeddings and attention mask to 64 to match DiT model + t5_embeddings = t5_embeddings[:, :64, :] + attention_mask = attention_mask[:, :64] + # Get the T5 projections + t5_proj = 
self.sao_t5.proj_out(t5_embeddings).to(dtype=self.dtype) + t5_proj = t5_proj * attention_mask.unsqueeze(-1).to(dtype=self.dtype) + t5_mask = attention_mask + + # Get seconds_total conditioner results + seconds_total_embedding, seconds_total_mask = self.sao_seconds_total_cond( + seconds_total + ) + + # Concatenate all cross-attention inputs (t5_embedding, seconds_total) over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = torch.cat( + [ + t5_proj, + seconds_total_embedding, + ], + dim=1, + ) + cross_attention_masks = torch.cat( + [ + t5_mask.to(torch.float), + seconds_total_mask.to(torch.float), + ], + dim=1, + ) + + # Concatenate all global conditioning inputs (seconds_start, seconds_total) over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat( + [ + seconds_total_embedding + ], + dim=-1, + ) + global_cond = global_cond.squeeze(1) + + return cross_attention_input, cross_attention_masks, global_cond + +def get_conditioners_module(model, dtype = torch.float): + """ + Wrap both the T5 encoder and seconds_total conditioner in a single module. + """ + # Load the SAO conditioners + sao_t5_cond, sao_seconds_total_cond = get_conditioners(model) + + # Return the conditioners module + return ConditionersModule( + sao_t5_cond=sao_t5_cond, + sao_seconds_total_cond=sao_seconds_total_cond, + dtype=dtype + ) + +def get_conditioners_example_input(seconds_total: float = 11, seq_length: int = 64, dtype=torch.float): + """Provide example input tensors for the AudioGen Conditioners submodule. + Args: + seconds_total (float): The total seconds for the audio. + seq_length (int): The sequence length for the T5 encoder. + Returns: + input_ids (torch.Tensor): The input IDs tensor for the T5 encoder. + attention_mask (torch.Tensor): The attention mask tensor for the T5 encoder. + seconds_total (torch.Tensor): The seconds_total tensor. 
+ """ + + # Create the input_ids and attention_mask tensors for sao conditioners + input_ids = torch.randint(1, 64, (1,seq_length)).to(torch.int64) + attention_mask = torch.ones((1, seq_length), dtype=torch.int64) + + # Create the seconds_total tensor + seconds_total = torch.tensor([seconds_total], dtype=dtype) + + return ( + input_ids, + attention_mask, + seconds_total, + ) + +## Utility functions for DiT +def get_dit_example_input_mapping(dtype=torch.float): + """Provide example input tensors for the DiT model as a dictionary. + Args: + dtype (torch.dtype): The data type for the input tensors. + Returns: + dict: A dictionary containing the example input tensors for the DiT model. + x (torch.Tensor): The input tensor for the DiT model. + t (torch.Tensor): The time tensor for the DiT model. + cross_attn_cond (torch.Tensor): The cross attention conditioning tensor for the DiT model. Output of the Conditioner T5 Encoder. + global_cond (torch.Tensor): The global conditioning tensor for the DiT model. Output of the Conditioner Number Encoder. 
+ """ + return { + "x": torch.rand(size=(1, 64, 256), dtype=dtype, requires_grad=False), # x + "t": torch.tensor([0.154], dtype=dtype, requires_grad=False), # t + "cross_attn_cond": torch.rand( + size=(1, 65, 768), dtype=dtype, requires_grad=False + ), # cross_attn_cond + "global_cond": torch.rand(size=(1, 768), dtype=dtype, requires_grad=False), # global_cond + } + +def get_dit_module(model, dtype = torch.float32): + dit_model = model.model + dit_model = dit_model.to(dtype).eval().requires_grad_(False) + return dit_model + + +## Utility functions for AutoEncoder +def get_autoencoder_decoder_module(model): + """Get the AutoEncoder module from the AudioGen model.""" + return AutoEncoderDecoderModule(model.pretransform) + +def get_autoencoder_decoder_example_input(dtype=torch.float): + """Get example input for the AutoEncoder module.""" + return (torch.rand((1, 64, 256), dtype=dtype),) + +class AutoEncoderDecoderModule(torch.nn.Module): + """Wrap the AutoEncoder Module. Takes the AutoEncoder and returns the audio. + Args: + autoencoder (torch.nn.Module): The AutoEncoder module. + Returns: + audio (torch.Tensor): The decoded audio tensor. 
+ """ + + def __init__(self, autoencoder): + super(AutoEncoderDecoderModule, self).__init__() + + self.autoencoder = ( + autoencoder.to(dtype=torch.float).eval().requires_grad_(False) + ) + + def forward(self, sampled: torch.Tensor): + + sampled_uncompressed = self.autoencoder.decode(sampled) + + audio = rearrange(sampled_uncompressed, "b d n -> d (b n)") + return audio + +## Utility functions for the encoder part of the autoencoder +def get_autoencoder_encoder_module(model): + """Get the AutoEncoder module from the AudioGen model.""" + return AutoEncoderEncoderModule(model.pretransform) + +def get_autoencoder_encoder_example_input(dtype=torch.float): + """Get example input for the AutoEncoder module.""" + return (torch.rand((1, 2, 524288), dtype=dtype),) + +def vae_sample_updated(mean, scale): + stdev = torch.nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + + # "randn_like" was causing failures while exporting the model: + # latents = torch.randn_like(mean) * stdev + mean + rand = torch.randn(mean.size()) + latents = rand * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + +class AutoEncoderEncoderModule(torch.nn.Module): + """Wrap the AutoEncoder Module. Takes the AutoEncoder and returns the audio. + Args: + autoencoder (torch.nn.Module): The AutoEncoder module. + Returns: + audio (torch.Tensor): The decoded audio tensor. 
+ """ + + def __init__(self, autoencoder): + super(AutoEncoderEncoderModule, self).__init__() + self.autoencoder = autoencoder + + # Use float + self.autoencoder = ( + self.autoencoder.to(dtype=torch.float).eval().requires_grad_(False) + ) + + stable_audio_tools.models.bottleneck.vae_sample = vae_sample_updated + + def forward(self, sampled: torch.Tensor): + dtype = torch.float + sample_compressed = self.autoencoder.encode(sampled.to(dtype)) + + return sample_compressed \ No newline at end of file diff --git a/kleidiai-examples/audiogen/scripts/utils_load_model.py b/kleidiai-examples/audiogen/scripts/utils_load_model.py deleted file mode 100644 index 9dc1e36..0000000 --- a/kleidiai-examples/audiogen/scripts/utils_load_model.py +++ /dev/null @@ -1,97 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates -# -# SPDX-License-Identifier: Apache-2.0 -# - -import logging -from typing import Any, Dict, Optional, Tuple - -import torch - -from stable_audio_tools.models.factory import create_model_from_config -from stable_audio_tools.models.pretrained import get_pretrained_model -from stable_audio_tools.models.utils import load_ckpt_state_dict - -MODEL = None -SAMPLE_RATE = 44100 -SAMPLE_SIZE = 524288 -DEVICE = torch.device("cpu") - - -## ----------------- Loading SAO Model ------------------- -def copy_state_dict(model, state_dict): - """Load state_dict to model, but only for keys that match exactly. - - Args: - model (nn.Module): model to load state_dict. - state_dict (OrderedDict): state_dict to load. 
- """ - model_state_dict = model.state_dict() - for key in state_dict: - if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: - if isinstance(state_dict[key], torch.nn.Parameter): - # backwards compatibility for serialized parameters - state_dict[key] = state_dict[key].data - model_state_dict[key] = state_dict[key] - - model.load_state_dict(model_state_dict, strict=False) - -def load_model( - model_config: Optional[Dict[str, Any]] = None, - model_ckpt_path: Optional[str] = None, - pretrained_name: Optional[str] = None, - pretransform_ckpt_path: Optional[str] = None, - device: torch.device = DEVICE, -) -> Tuple[torch.nn.Module, Dict[str, Any]]: - """Load the AudioGen model and its configuration. - - Either a pretrained model (via `pretrained_name`) or a freshly constructed one - (via `model_config` + `model_ckpt_path`) will be loaded. - - Args: - model_config: Configuration dict for creating the model. - model_ckpt_path: Path to a model checkpoint file. - pretrained_name: Name of a model to load from the repo. - pretransform_ckpt_path: Optional path to a pretransform checkpoint. - device: Torch device to map the model to. - - Returns: - A tuple of (model, model_config), where `model` is in eval mode - and cast to float, and `model_config` contains sample_rate/size, etc. 
- """ - global MODEL, SAMPLE_RATE, SAMPLE_SIZE - - if pretrained_name is not None: - logging.info("Loading pretrained model: %s", pretrained_name) - model, model_config = get_pretrained_model(pretrained_name) - - elif model_config is not None: - if model_ckpt_path is None: - raise ValueError( - "model_ckpt_path must be provided when specifying model_config" - ) - logging.info("Creating model from config") - model = create_model_from_config(model_config) - - logging.info("Loading model checkpoint from: %s", model_ckpt_path) - # Load checkpoint - copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) - logging.info("Done loading model checkpoint") - - SAMPLE_RATE = model_config["sample_rate"] - SAMPLE_SIZE = model_config["sample_size"] - - if pretransform_ckpt_path is not None: - logging.info("Loading pretransform checkpoint from %r", pretransform_ckpt_path) - model.pretransform.load_state_dict( - load_ckpt_state_dict(pretransform_ckpt_path), strict=False - ) - logging.info("Done loading pretransform.") - - model.to(device).eval().requires_grad_(False) - model = model.to(torch.float) - model.pretransform.model_half=False - - print("Done loading model") - return model, model_config