From 6c6ac7e3ba8242ba9bbe93bb39d6ea3c277bbcfc Mon Sep 17 00:00:00 2001 From: StarLight63 <13123813+starlight63@user.noreply.gitee.com> Date: Mon, 18 May 2026 17:10:13 +0800 Subject: [PATCH 1/2] feat/audioOps --- runtime/ops/mapper/__init__.py | 25 +- .../ops/mapper/audio_anomaly_filter/README.md | 41 + .../mapper/audio_anomaly_filter/__init__.py | 6 + .../mapper/audio_anomaly_filter/audio_skip.py | 114 + .../mapper/audio_anomaly_filter/metadata.yml | 66 + .../mapper/audio_anomaly_filter/process.py | 221 ++ .../audio_anomaly_filter/requirements.txt | 2 + .../ops/mapper/audio_asr_pipeline/README.md | 62 + .../ops/mapper/audio_asr_pipeline/__init__.py | 6 + .../config/audio_config.yaml | 8 + .../audio_preprocessor/config/eval_wer.yaml | 6 + .../config/merge_asr_by_source.yaml | 6 + .../speechbrain/speechbrain/__init__.py | 71 + .../speechbrain/alignment/__init__.py | 1 + .../speechbrain/alignment/aligner.py | 1494 +++++++++ .../speechbrain/alignment/ctc_segmentation.py | 11 + .../speechbrain/augment/__init__.py | 1 + .../speechbrain/augment/augmenter.py | 544 +++ .../speechbrain/speechbrain/augment/codec.py | 92 + .../speechbrain/augment/freq_domain.py | 399 +++ .../speechbrain/augment/preparation.py | 219 ++ .../speechbrain/augment/time_domain.py | 1540 +++++++++ .../speechbrain/speechbrain/core.py | 1489 +++++++++ .../speechbrain/dataio/__init__.py | 5 + .../speechbrain/dataio/audio_io.py | 228 ++ .../speechbrain/speechbrain/dataio/batch.py | 333 ++ .../speechbrain/speechbrain/dataio/dataio.py | 1417 ++++++++ .../speechbrain/dataio/dataloader.py | 420 +++ .../speechbrain/speechbrain/dataio/dataset.py | 546 +++ .../speechbrain/speechbrain/dataio/encoder.py | 1216 +++++++ .../speechbrain/dataio/iterators.py | 235 ++ .../speechbrain/speechbrain/dataio/legacy.py | 321 ++ .../speechbrain/dataio/preprocess.py | 82 + .../speechbrain/speechbrain/dataio/sampler.py | 845 +++++ .../speechbrain/speechbrain/dataio/wer.py | 201 ++ .../speechbrain/decoders/__init__.py | 6 + .../speechbrain/speechbrain/decoders/ctc.py | 1905 +++++++++++ .../speechbrain/decoders/language_model.py | 11 + .../speechbrain/decoders/scorer.py | 2189 ++++++++++++ .../speechbrain/decoders/seq2seq.py | 2240 +++++++++++++ .../speechbrain/decoders/transducer.py | 648 ++++ .../speechbrain/speechbrain/decoders/utils.py | 158 + .../speechbrain/speechbrain/inference/ASR.py | 1546 +++++++++ .../speechbrain/speechbrain/inference/SLU.py | 144 + .../speechbrain/speechbrain/inference/ST.py | 138 + .../speechbrain/speechbrain/inference/TTS.py | 928 ++++++ .../speechbrain/speechbrain/inference/VAD.py | 965 ++++++ .../speechbrain/inference/__init__.py | 17 + .../speechbrain/inference/classifiers.py | 322 ++ .../speechbrain/inference/diarization.py | 241 ++ .../speechbrain/inference/encoders.py | 272 ++ .../speechbrain/inference/enhancement.py | 373 +++ .../speechbrain/inference/interfaces.py | 694 ++++ .../speechbrain/inference/interpretability.py | 182 + .../speechbrain/inference/metrics.py | 97 + .../speechbrain/inference/separation.py | 129 + .../speechbrain/inference/speaker.py | 133 + .../speechbrain/speechbrain/inference/text.py | 443 +++ .../speechbrain/inference/vocoders.py | 399 +++ .../speechbrain/integrations/README.md | 33 + .../speechbrain/integrations/__init__.py | 7 + .../integrations/alignment/README.md | 31 + .../integrations/alignment/__init__.py | 3 + .../integrations/alignment/ctc_seg.py | 675 ++++ .../integrations/alignment/diarization.py | 1231 +++++++ .../integrations/audio_tokenizers/README.md | 45 + .../integrations/audio_tokenizers/__init__.py | 3 + .../audio_tokenizers/discrete_ssl.py | 408 +++ .../integrations/audio_tokenizers/kmeans.py | 178 + .../speechtokenizer_interface.py | 157 + .../wavtokenizer_interface.py | 168 + .../integrations/decoders/README.md | 30 + .../integrations/decoders/__init__.py | 3 + .../integrations/decoders/kenlm_scorer.py | 321 ++ .../speechbrain/integrations/hdf5/README.md | 30 + .../speechbrain/integrations/hdf5/__init__.py | 7 + .../integrations/hdf5/cached_item.py | 159 + .../integrations/huggingface/README.md | 70 + .../integrations/huggingface/__init__.py | 20 + .../integrations/huggingface/encodec.py | 385 +++ .../integrations/huggingface/gpt.py | 179 + .../integrations/huggingface/hubert.py | 88 + .../integrations/huggingface/huggingface.py | 455 +++ .../integrations/huggingface/labse.py | 116 + .../integrations/huggingface/llama.py | 198 ++ .../integrations/huggingface/mbart.py | 221 ++ .../integrations/huggingface/mert.py | 88 + .../integrations/huggingface/mimi.py | 191 ++ .../integrations/huggingface/nllb.py | 75 + .../integrations/huggingface/textencoder.py | 122 + .../integrations/huggingface/vocos.py | 158 + .../integrations/huggingface/w2v_bert.py | 200 ++ .../integrations/huggingface/wav2vec2.py | 332 ++ .../integrations/huggingface/wavlm.py | 88 + .../integrations/huggingface/weighted_ssl.py | 122 + .../integrations/huggingface/whisper.py | 637 ++++ .../huggingface/wordemb/__init__.py | 1 + .../huggingface/wordemb/transformer.py | 289 ++ .../integrations/huggingface/wordemb/util.py | 72 + .../speechbrain/integrations/k2_fsa/README.md | 38 + .../integrations/k2_fsa/__init__.py | 20 + .../speechbrain/integrations/k2_fsa/align.py | 667 ++++ .../integrations/k2_fsa/graph_compiler.py | 387 +++ .../integrations/k2_fsa/lattice_decoder.py | 453 +++ .../integrations/k2_fsa/lexicon.py | 584 ++++ .../speechbrain/integrations/k2_fsa/losses.py | 134 + .../integrations/k2_fsa/prepare_lang.py | 575 ++++ .../speechbrain/integrations/k2_fsa/utils.py | 168 + .../speechbrain/integrations/models/README.md | 28 + .../integrations/models/__init__.py | 3 + .../integrations/models/sgmse_plus.py | 615 ++++ .../speechbrain/integrations/nlp/README.md | 36 + .../speechbrain/integrations/nlp/__init__.py | 5 + .../integrations/nlp/bgeM3_embeddings.py | 180 + .../speechbrain/integrations/nlp/bleu.py | 105 + .../integrations/nlp/flair_embeddings.py | 150 + .../integrations/nlp/flair_tagger.py | 87 + .../integrations/nlp/spacy_pipeline.py | 144 + .../speechbrain/integrations/numba/README.md | 25 + .../integrations/numba/__init__.py | 18 + .../integrations/numba/transducer_loss.py | 354 ++ .../integrations/tests/test_cached_item.py | 506 +++ .../tests/test_ctc_segmentation.py | 85 + .../speechbrain/integrations/tests/test_k2.py | 458 +++ .../integrations/tests/test_nlp.py | 78 + .../speechbrain/speechbrain/lm/__init__.py | 1 + .../speechbrain/speechbrain/lm/arpa.py | 353 ++ .../speechbrain/speechbrain/lm/counting.py | 166 + .../speechbrain/speechbrain/lm/ngram.py | 210 ++ .../speechbrain/speechbrain/lobes/__init__.py | 9 + .../speechbrain/lobes/beamform_multimic.py | 50 + .../speechbrain/lobes/downsampling.py | 176 + .../speechbrain/speechbrain/lobes/features.py | 862 +++++ .../speechbrain/lobes/models/BESTRQ.py | 128 + .../speechbrain/lobes/models/CRDNN.py | 315 ++ .../speechbrain/lobes/models/Cnn14.py | 422 +++ .../speechbrain/lobes/models/ContextNet.py | 304 ++ .../speechbrain/lobes/models/DiffWave.py | 701 ++++ .../speechbrain/lobes/models/ECAPA_TDNN.py | 636 ++++ .../speechbrain/lobes/models/ESPnetVGG.py | 128 + .../speechbrain/lobes/models/EnhanceResnet.py | 251 ++ .../speechbrain/lobes/models/FastSpeech2.py | 2924 +++++++++++++++++ .../speechbrain/lobes/models/GatedNN.py | 135 + .../speechbrain/lobes/models/HifiGAN.py | 1838 +++++++++++ .../speechbrain/lobes/models/L2I.py | 581 ++++ .../speechbrain/lobes/models/MSTacotron2.py | 754 +++++ .../speechbrain/lobes/models/MetricGAN.py | 195 ++ .../speechbrain/lobes/models/MetricGAN_U.py | 193 ++ .../speechbrain/lobes/models/PIQ.py | 699 ++++ .../speechbrain/lobes/models/RNNLM.py | 124 + .../speechbrain/lobes/models/ResNet.py | 520 +++ .../speechbrain/lobes/models/Tacotron2.py | 1886 +++++++++++ .../speechbrain/lobes/models/VanillaNN.py | 51 + .../speechbrain/lobes/models/Xvector.py | 246 ++ .../speechbrain/lobes/models/__init__.py | 1 + .../speechbrain/lobes/models/beats.py | 2096 ++++++++++++ .../speechbrain/lobes/models/bsq.py | 181 + .../speechbrain/lobes/models/conv_tasnet.py | 622 ++++ .../speechbrain/lobes/models/convolution.py | 320 ++ .../lobes/models/discrete/__init__.py | 6 + .../speechbrain/lobes/models/discrete/dac.py | 1122 +++++++ .../speechbrain/lobes/models/dual_path.py | 1494 +++++++++ .../lobes/models/fairseq_wav2vec.py | 362 ++ .../speechbrain/lobes/models/g2p/__init__.py | 5 + .../speechbrain/lobes/models/g2p/dataio.py | 688 ++++ .../speechbrain/lobes/models/g2p/homograph.py | 681 ++++ .../speechbrain/lobes/models/g2p/model.py | 582 ++++ .../speechbrain/lobes/models/kmeans.py | 11 + .../speechbrain/lobes/models/resepformer.py | 781 +++++ .../speechbrain/lobes/models/segan_model.py | 253 ++ .../lobes/models/transformer/Branchformer.py | 409 +++ .../lobes/models/transformer/Conformer.py | 1153 +++++++ .../lobes/models/transformer/Transformer.py | 1100 +++++++ .../models/transformer/TransformerASR.py | 726 ++++ .../lobes/models/transformer/TransformerLM.py | 187 ++ .../lobes/models/transformer/TransformerSE.py | 104 + .../lobes/models/transformer/TransformerST.py | 437 +++ .../lobes/models/transformer/__init__.py | 5 + .../speechbrain/lobes/models/wav2vec.py | 413 +++ .../speechbrain/speechbrain/log-config.yaml | 25 + .../speechbrain/speechbrain/nnet/CNN.py | 1571 +++++++++ .../speechbrain/speechbrain/nnet/RNN.py | 2171 ++++++++++++ .../speechbrain/speechbrain/nnet/__init__.py | 7 + .../speechbrain/nnet/activations.py | 171 + .../speechbrain/speechbrain/nnet/adapters.py | 389 +++ .../speechbrain/speechbrain/nnet/attention.py | 1440 ++++++++ .../speechbrain/nnet/autoencoders.py | 481 +++ .../nnet/complex_networks/__init__.py | 1 + .../nnet/complex_networks/c_CNN.py | 498 +++ .../nnet/complex_networks/c_RNN.py | 1295 ++++++++ .../nnet/complex_networks/c_linear.py | 124 + .../nnet/complex_networks/c_normalization.py | 745 +++++ .../nnet/complex_networks/c_ops.py | 355 ++ .../speechbrain/nnet/containers.py | 408 +++ .../speechbrain/speechbrain/nnet/diffusion.py | 676 ++++ .../speechbrain/speechbrain/nnet/dropout.py | 60 + .../speechbrain/speechbrain/nnet/embedding.py | 120 + .../speechbrain/nnet/hypermixing.py | 372 +++ .../speechbrain/speechbrain/nnet/linear.py | 91 + .../speechbrain/nnet/loss/__init__.py | 1 + .../speechbrain/nnet/loss/guidedattn_loss.py | 178 + .../speechbrain/nnet/loss/si_snr_loss.py | 66 + .../speechbrain/nnet/loss/stoi_loss.py | 226 ++ .../speechbrain/speechbrain/nnet/losses.py | 1990 +++++++++++ .../speechbrain/nnet/normalization.py | 668 ++++ .../speechbrain/speechbrain/nnet/pooling.py | 609 ++++ .../speechbrain/nnet/quantisers.py | 184 ++ .../nnet/quaternion_networks/__init__.py | 1 + .../nnet/quaternion_networks/q_CNN.py | 681 ++++ .../nnet/quaternion_networks/q_RNN.py | 1313 ++++++++ .../nnet/quaternion_networks/q_linear.py | 242 ++ .../quaternion_networks/q_normalization.py | 162 + .../nnet/quaternion_networks/q_ops.py | 886 +++++ .../nnet/quaternion_networks/q_pooling.py | 125 + .../speechbrain/nnet/schedulers.py | 1710 ++++++++++ .../speechbrain/nnet/transducer/__init__.py | 1 + .../nnet/transducer/transducer_joint.py | 102 + .../speechbrain/speechbrain/nnet/unet.py | 1842 +++++++++++ .../speechbrain/speechbrain/nnet/utils.py | 88 + .../speechbrain/speechbrain/processing/NMF.py | 198 ++ .../speechbrain/processing/PLDA_LDA.py | 1072 ++++++ .../speechbrain/processing/__init__.py | 1 + .../speechbrain/processing/decomposition.py | 441 +++ .../speechbrain/processing/diarization.py | 11 + .../speechbrain/processing/features.py | 1913 +++++++++++ .../speechbrain/processing/multi_mic.py | 1589 +++++++++ .../processing/signal_processing.py | 652 ++++ .../speechbrain/processing/vocal_features.py | 520 +++ .../speechbrain/tokenizers/SentencePiece.py | 575 ++++ .../speechbrain/tokenizers/__init__.py | 1 + .../tokenizers/discrete_SSL_tokenizer.py | 127 + .../speechbrain/speechbrain/utils/Accuracy.py | 103 + .../speechbrain/speechbrain/utils/DER.py | 152 + .../speechbrain/speechbrain/utils/EDER.py | 286 ++ .../speechbrain/speechbrain/utils/__init__.py | 7 + .../speechbrain/utils/_workarounds.py | 36 + .../speechbrain/speechbrain/utils/autocast.py | 252 ++ .../speechbrain/utils/bertscore.py | 351 ++ .../speechbrain/speechbrain/utils/bleu.py | 11 + .../speechbrain/utils/callchains.py | 85 + .../speechbrain/utils/checkpoints.py | 1384 ++++++++ .../speechbrain/utils/data_pipeline.py | 690 ++++ .../speechbrain/utils/data_utils.py | 1262 +++++++ .../speechbrain/speechbrain/utils/depgraph.py | 273 ++ .../speechbrain/utils/dictionaries.py | 122 + .../speechbrain/utils/distances.py | 50 + .../speechbrain/utils/distributed.py | 501 +++ .../utils/dynamic_chunk_training.py | 188 ++ .../speechbrain/utils/edit_distance.py | 797 +++++ .../speechbrain/utils/epoch_loop.py | 201 ++ .../speechbrain/speechbrain/utils/fetching.py | 436 +++ .../speechbrain/utils/filter_analysis.py | 226 ++ .../speechbrain/speechbrain/utils/hparams.py | 37 + .../speechbrain/speechbrain/utils/hpopt.py | 494 +++ .../speechbrain/utils/importutils.py | 309 ++ .../speechbrain/speechbrain/utils/kmeans.py | 229 ++ .../speechbrain/speechbrain/utils/logger.py | 320 ++ .../speechbrain/utils/metric_stats.py | 1425 ++++++++ .../speechbrain/utils/optimizers.py | 37 + .../speechbrain/speechbrain/utils/parallel.py | 346 ++ .../speechbrain/utils/parameter_transfer.py | 350 ++ .../speechbrain/utils/pretrained.py | 96 + .../speechbrain/utils/profiling.py | 40 + .../speechbrain/speechbrain/utils/quirks.py | 123 + .../speechbrain/speechbrain/utils/repro.py | 172 + .../speechbrain/speechbrain/utils/run_opts.py | 363 ++ .../speechbrain/speechbrain/utils/seed.py | 71 + .../speechbrain/speechbrain/utils/semdist.py | 197 ++ .../speechbrain/utils/streaming.py | 235 ++ .../speechbrain/utils/superpowers.py | 87 + .../speechbrain/utils/text_to_sequence.py | 388 +++ .../speechbrain/utils/torch_audio_backend.py | 107 + .../speechbrain/utils/train_logger.py | 484 +++ .../speechbrain/speechbrain/version.txt | 1 + .../local_libs/wenet/wenet/README.md | 29 + .../local_libs/wenet/wenet/__init__.py | 1 + .../local_libs/wenet/wenet/cli/__init__.py | 0 .../local_libs/wenet/wenet/cli/hub.py | 123 + .../local_libs/wenet/wenet/cli/model.py | 110 + .../local_libs/wenet/wenet/cli/punc_model.py | 116 + .../local_libs/wenet/wenet/cli/transcribe.py | 72 + .../wenet/wenet/dataset/__init__.py | 0 .../wenet/wenet/dataset/datapipes.py | 470 +++ .../local_libs/wenet/wenet/dataset/dataset.py | 155 + .../wenet/dataset/deprecated/__init__.py | 0 .../wenet/wenet/dataset/deprecated/dataset.py | 202 ++ .../wenet/dataset/deprecated/processor.py | 665 ++++ .../wenet/wenet/dataset/kaldi_io.py | 772 +++++ .../wenet/wenet/dataset/processor.py | 596 ++++ .../wenet/wenet/dataset/wav_distortion.py | 336 ++ .../local_libs/wenet/wenet/models/__init__.py | 0 .../wenet/models/branchformer/__init__.py | 0 .../wenet/wenet/models/branchformer/cgmlp.py | 194 ++ .../wenet/models/branchformer/encoder.py | 176 + .../models/branchformer/encoder_layer.py | 246 ++ .../wenet/wenet/models/ctl_model/__init__.py | 0 .../wenet/models/ctl_model/asr_model_ctl.py | 278 ++ .../wenet/wenet/models/ctl_model/encoder.py | 173 + .../wenet/models/e_branchformer/__init__.py | 0 .../wenet/models/e_branchformer/encoder.py | 164 + .../models/e_branchformer/encoder_layer.py | 188 ++ .../models/efficient_conformer/__init__.py | 0 .../models/efficient_conformer/attention.py | 258 ++ .../models/efficient_conformer/convolution.py | 154 + .../models/efficient_conformer/encoder.py | 557 ++++ .../efficient_conformer/encoder_layer.py | 165 + .../models/efficient_conformer/subsampling.py | 75 + .../wenet/wenet/models/finetune/__init__.py | 0 .../wenet/models/finetune/lora/__init__.py | 0 .../wenet/models/finetune/lora/config.yaml | 13 + .../wenet/models/finetune/lora/layers.py | 350 ++ .../wenet/wenet/models/finetune/lora/utils.py | 334 ++ .../wenet/wenet/models/firered/__init__.py | 0 .../wenet/wenet/models/firered/attention.py | 183 ++ ..._FireRed_AED_L_to_wenet_config_and_ckpt.py | 336 ++ .../wenet/wenet/models/firered/encoder.py | 129 + .../wenet/models/firered/encoder_layer.py | 43 + .../wenet/wenet/models/firered/model.py | 62 + .../wenet/wenet/models/firered/subsampling.py | 75 + .../wenet/wenet/models/k2/__init__.py | 0 .../local_libs/wenet/wenet/models/k2/model.py | 304 ++ .../wenet/wenet/models/paraformer/__init__.py | 0 .../wenet/models/paraformer/attention.py | 219 ++ .../wenet/wenet/models/paraformer/cif.py | 293 ++ ...ert_paraformer_to_wenet_config_and_ckpt.py | 329 ++ .../wenet/models/paraformer/embedding.py | 14 + .../wenet/wenet/models/paraformer/layers.py | 496 +++ .../wenet/models/paraformer/paraformer.py | 413 +++ .../wenet/wenet/models/paraformer/search.py | 256 ++ .../wenet/models/paraformer/subsampling.py | 50 + .../wenet/wenet/models/sensevoice/__init__.py | 0 ...nsevoice_small_to_wenet_config_and_ckpt.py | 170 + .../sensevoice/sensevoice_small_model.py | 290 ++ .../wenet/models/squeezeformer/__init__.py | 0 .../wenet/models/squeezeformer/attention.py | 234 ++ .../wenet/models/squeezeformer/conv2d.py | 66 + .../wenet/models/squeezeformer/convolution.py | 177 + .../wenet/models/squeezeformer/encoder.py | 469 +++ .../models/squeezeformer/encoder_layer.py | 121 + .../positionwise_feed_forward.py | 77 + .../wenet/models/squeezeformer/subsampling.py | 323 ++ .../wenet/wenet/models/ssl/__init__.py | 0 .../wenet/wenet/models/ssl/bestrq/__init__.py | 0 .../wenet/models/ssl/bestrq/bestrq_model.py | 298 ++ .../wenet/wenet/models/ssl/bestrq/mask.py | 160 + .../wenet/wenet/models/ssl/init_dataset.py | 157 + .../wenet/wenet/models/ssl/init_model.py | 19 + .../wenet/models/ssl/w2vbert/__init__.py | 0 ...onvert_w2vbert_to_wenet_config_and_ckpt.py | 194 ++ .../wenet/models/ssl/w2vbert/w2vbert_model.py | 320 ++ .../wenet/models/ssl/wav2vec2/__init__.py | 0 .../wenet/models/ssl/wav2vec2/quantizer.py | 113 + .../models/ssl/wav2vec2/wav2vec2_model.py | 325 ++ .../wenet/wenet/models/transducer/__init__.py | 0 .../wenet/wenet/models/transducer/joint.py | 106 + .../wenet/models/transducer/predictor.py | 495 +++ .../models/transducer/search/__init__.py | 0 .../models/transducer/search/greedy_search.py | 54 + .../transducer/search/prefix_beam_search.py | 148 + .../wenet/models/transducer/transducer.py | 572 ++++ .../wenet/models/transformer/__init__.py | 0 .../wenet/models/transformer/asr_model.py | 547 +++ .../wenet/models/transformer/attention.py | 686 ++++ .../wenet/wenet/models/transformer/cmvn.py | 47 + .../wenet/models/transformer/convolution.py | 153 + .../wenet/wenet/models/transformer/ctc.py | 91 + .../wenet/wenet/models/transformer/decoder.py | 494 +++ .../wenet/models/transformer/decoder_layer.py | 153 + .../wenet/models/transformer/embedding.py | 259 ++ .../wenet/wenet/models/transformer/encoder.py | 552 ++++ .../wenet/models/transformer/encoder_layer.py | 265 ++ .../transformer/label_smoothing_loss.py | 96 + .../wenet/wenet/models/transformer/norm.py | 27 + .../transformer/positionwise_feed_forward.py | 159 + .../wenet/wenet/models/transformer/search.py | 458 +++ .../wenet/models/transformer/subsampling.py | 394 +++ .../wenet/wenet/models/transformer/swish.py | 26 + .../wenet/wenet/models/whisper/__init__.py | 0 ...onvert_whisper_to_wenet_config_and_ckpt.py | 310 ++ .../wenet/wenet/models/whisper/whisper.py | 96 + .../local_libs/wenet/wenet/text/__init__.py | 0 .../wenet/wenet/text/base_tokenizer.py | 41 + .../wenet/wenet/text/bpe_tokenizer.py | 51 + .../wenet/wenet/text/char_tokenizer.py | 80 + .../wenet/text/hugging_face_tokenizer.py | 58 + .../wenet/wenet/text/paraformer_tokenizer.py | 53 + .../wenet/text/sentencepiece_tokenizer.py | 57 + .../wenet/wenet/text/tokenize_utils.py | 65 + .../wenet/wenet/text/whisper_tokenizer.py | 103 + .../local_libs/wenet/wenet/utils/__init__.py | 0 .../wenet/wenet/utils/checkpoint.py | 117 + .../wenet/wenet/utils/class_utils.py | 98 + .../local_libs/wenet/wenet/utils/cmvn.py | 93 + .../local_libs/wenet/wenet/utils/common.py | 377 +++ .../local_libs/wenet/wenet/utils/config.py | 39 + .../wenet/wenet/utils/context_graph.py | 265 ++ .../local_libs/wenet/wenet/utils/ctc_utils.py | 136 + .../local_libs/wenet/wenet/utils/executor.py | 161 + .../wenet/wenet/utils/file_utils.py | 69 + .../wenet/wenet/utils/fsdp_utils.py | 116 + .../wenet/wenet/utils/init_dataset.py | 42 + .../wenet/wenet/utils/init_model.py | 217 ++ .../wenet/wenet/utils/init_tokenizer.py | 58 + .../local_libs/wenet/wenet/utils/mask.py | 373 +++ .../wenet/wenet/utils/rope_utils.py | 39 + .../local_libs/wenet/wenet/utils/scheduler.py | 722 ++++ .../wenet/wenet/utils/train_utils.py | 930 ++++++ .../scripts/audio_convert/__init__.py | 3 + .../scripts/audio_convert/__main__.py | 7 + .../scripts/audio_convert/audio_convert.py | 377 +++ .../scripts/audio_convert/config_loader.py | 149 + .../scripts/audio_convert/readme.md | 29 + .../src/pipeline/0_normalization.py | 20 + .../src/pipeline/1_denoise.py | 77 + .../src/pipeline/2_anomaly_filter.py | 25 + .../src/pipeline/3_fast_lang_id.py | 57 + .../src/pipeline/4_split_and_tag.py | 20 + .../src/pipeline/5_recognize_monitor.py | 22 + .../src/pipeline/6_eval_wer.py | 25 + .../src/pipeline/7_eval_keyword_recall.py | 26 + .../src/pipeline/anomaly_filter.py | 327 ++ .../src/pipeline/eval_keyword_recall.py | 351 ++ .../src/pipeline/eval_wer.py | 300 ++ .../src/pipeline/merge_asr_by_source.py | 130 + .../src/pipeline/normalization.py | 352 ++ .../src/pipeline/recognize_monitor.py | 300 ++ .../src/pipeline/split_and_tag.py | 228 ++ .../src/tools/audio_anomaly_filter.py | 417 +++ .../src/tools/convert_audio.py | 678 ++++ .../src/tools/gtcrn_denoise.py | 150 + .../audio_preprocessor/src/tools/readme.txt | 1 + .../audio_preprocessor/src/tools/recognize.py | 189 ++ .../src/tools/split_audio.py | 109 + .../src/utils/color_utils.py | 67 + .../src/utils/compute_wer.py | 553 ++++ .../src/utils/fast_lang_id.py | 487 +++ .../src/utils/generate_audio_list.py | 261 ++ .../src/utils/gtcrn_denoise.py | 349 ++ .../audio_preprocessor/src/utils/recognize.py | 336 ++ .../audio_preprocessor/src/utils/run_wenet.py | 38 + .../src/utils/yaml_config_loader.py | 119 + .../mapper/audio_asr_pipeline/audio_skip.py | 114 + .../mapper/audio_asr_pipeline/metadata.yml | 155 + .../ops/mapper/audio_asr_pipeline/process.py | 558 ++++ .../audio_asr_pipeline/requirements.txt | 7 + .../ops/mapper/audio_asr_transcribe/README.md | 68 + .../mapper/audio_asr_transcribe/__init__.py | 6 + .../local_libs/wenet/wenet/README.md | 29 + .../local_libs/wenet/wenet/__init__.py | 1 + .../local_libs/wenet/wenet/cli/__init__.py | 0 .../local_libs/wenet/wenet/cli/hub.py | 123 + .../local_libs/wenet/wenet/cli/model.py | 110 + .../local_libs/wenet/wenet/cli/punc_model.py | 116 + .../local_libs/wenet/wenet/cli/transcribe.py | 72 + .../wenet/wenet/dataset/__init__.py | 0 .../wenet/wenet/dataset/datapipes.py | 470 +++ .../local_libs/wenet/wenet/dataset/dataset.py | 155 + .../wenet/dataset/deprecated/__init__.py | 0 .../wenet/wenet/dataset/deprecated/dataset.py | 202 ++ .../wenet/dataset/deprecated/processor.py | 665 ++++ .../wenet/wenet/dataset/kaldi_io.py | 772 +++++ .../wenet/wenet/dataset/processor.py | 596 ++++ .../wenet/wenet/dataset/wav_distortion.py | 336 ++ .../local_libs/wenet/wenet/models/__init__.py | 0 .../wenet/models/branchformer/__init__.py | 0 .../wenet/wenet/models/branchformer/cgmlp.py | 194 ++ .../wenet/models/branchformer/encoder.py | 176 + .../models/branchformer/encoder_layer.py | 246 ++ .../wenet/wenet/models/ctl_model/__init__.py | 0 .../wenet/models/ctl_model/asr_model_ctl.py | 278 ++ .../wenet/wenet/models/ctl_model/encoder.py | 173 + .../wenet/models/e_branchformer/__init__.py | 0 .../wenet/models/e_branchformer/encoder.py | 164 + .../models/e_branchformer/encoder_layer.py | 188 ++ .../models/efficient_conformer/__init__.py | 0 .../models/efficient_conformer/attention.py | 258 ++ .../models/efficient_conformer/convolution.py | 154 + .../models/efficient_conformer/encoder.py | 557 ++++ .../efficient_conformer/encoder_layer.py | 165 + .../models/efficient_conformer/subsampling.py | 75 + .../wenet/wenet/models/finetune/__init__.py | 0 .../wenet/models/finetune/lora/__init__.py | 0 .../wenet/models/finetune/lora/config.yaml | 13 + .../wenet/models/finetune/lora/layers.py | 350 ++ .../wenet/wenet/models/finetune/lora/utils.py | 334 ++ .../wenet/wenet/models/firered/__init__.py | 0 .../wenet/wenet/models/firered/attention.py | 183 ++ ..._FireRed_AED_L_to_wenet_config_and_ckpt.py | 336 ++ .../wenet/wenet/models/firered/encoder.py | 129 + .../wenet/models/firered/encoder_layer.py | 43 + .../wenet/wenet/models/firered/model.py | 62 + .../wenet/wenet/models/firered/subsampling.py | 75 + .../wenet/wenet/models/k2/__init__.py | 0 .../local_libs/wenet/wenet/models/k2/model.py | 304 ++ .../wenet/wenet/models/paraformer/__init__.py | 0 .../wenet/models/paraformer/attention.py | 219 ++ .../wenet/wenet/models/paraformer/cif.py | 293 ++ ...ert_paraformer_to_wenet_config_and_ckpt.py | 329 ++ .../wenet/models/paraformer/embedding.py | 14 + .../wenet/wenet/models/paraformer/layers.py | 496 +++ .../wenet/models/paraformer/paraformer.py | 413 +++ .../wenet/wenet/models/paraformer/search.py | 256 ++ .../wenet/models/paraformer/subsampling.py | 50 + .../wenet/wenet/models/sensevoice/__init__.py | 0 ...nsevoice_small_to_wenet_config_and_ckpt.py | 170 + .../sensevoice/sensevoice_small_model.py | 290 ++ .../wenet/models/squeezeformer/__init__.py | 0 .../wenet/models/squeezeformer/attention.py | 234 ++ .../wenet/models/squeezeformer/conv2d.py | 66 + .../wenet/models/squeezeformer/convolution.py | 177 + .../wenet/models/squeezeformer/encoder.py | 469 +++ .../models/squeezeformer/encoder_layer.py | 121 + .../positionwise_feed_forward.py | 77 + .../wenet/models/squeezeformer/subsampling.py | 323 ++ .../wenet/wenet/models/ssl/__init__.py | 0 .../wenet/wenet/models/ssl/bestrq/__init__.py | 0 .../wenet/models/ssl/bestrq/bestrq_model.py | 298 ++ .../wenet/wenet/models/ssl/bestrq/mask.py | 160 + .../wenet/wenet/models/ssl/init_dataset.py | 157 + .../wenet/wenet/models/ssl/init_model.py | 19 + .../wenet/models/ssl/w2vbert/__init__.py | 0 ...onvert_w2vbert_to_wenet_config_and_ckpt.py | 194 ++ .../wenet/models/ssl/w2vbert/w2vbert_model.py | 320 ++ .../wenet/models/ssl/wav2vec2/__init__.py | 0 .../wenet/models/ssl/wav2vec2/quantizer.py | 113 + .../models/ssl/wav2vec2/wav2vec2_model.py | 325 ++ .../wenet/wenet/models/transducer/__init__.py | 0 .../wenet/wenet/models/transducer/joint.py | 106 + .../wenet/models/transducer/predictor.py | 495 +++ .../models/transducer/search/__init__.py | 0 .../models/transducer/search/greedy_search.py | 54 + .../transducer/search/prefix_beam_search.py | 148 + .../wenet/models/transducer/transducer.py | 572 ++++ .../wenet/models/transformer/__init__.py | 0 .../wenet/models/transformer/asr_model.py | 547 +++ .../wenet/models/transformer/attention.py | 686 ++++ .../wenet/wenet/models/transformer/cmvn.py | 47 + .../wenet/models/transformer/convolution.py | 153 + .../wenet/wenet/models/transformer/ctc.py | 91 + .../wenet/wenet/models/transformer/decoder.py | 494 +++ .../wenet/models/transformer/decoder_layer.py | 153 + .../wenet/models/transformer/embedding.py | 259 ++ .../wenet/wenet/models/transformer/encoder.py | 552 ++++ .../wenet/models/transformer/encoder_layer.py | 265 ++ .../transformer/label_smoothing_loss.py | 96 + .../wenet/wenet/models/transformer/norm.py | 27 + .../transformer/positionwise_feed_forward.py | 159 + .../wenet/wenet/models/transformer/search.py | 458 +++ .../wenet/models/transformer/subsampling.py | 394 +++ .../wenet/wenet/models/transformer/swish.py | 26 + .../wenet/wenet/models/whisper/__init__.py | 0 ...onvert_whisper_to_wenet_config_and_ckpt.py | 310 ++ .../wenet/wenet/models/whisper/whisper.py | 96 + .../local_libs/wenet/wenet/text/__init__.py | 0 .../wenet/wenet/text/base_tokenizer.py | 41 + .../wenet/wenet/text/bpe_tokenizer.py | 51 + .../wenet/wenet/text/char_tokenizer.py | 80 + .../wenet/text/hugging_face_tokenizer.py | 58 + .../wenet/wenet/text/paraformer_tokenizer.py | 53 + .../wenet/text/sentencepiece_tokenizer.py | 57 + .../wenet/wenet/text/tokenize_utils.py | 65 + .../wenet/wenet/text/whisper_tokenizer.py | 103 + .../local_libs/wenet/wenet/utils/__init__.py | 0 .../wenet/wenet/utils/checkpoint.py | 117 + .../wenet/wenet/utils/class_utils.py | 98 + .../local_libs/wenet/wenet/utils/cmvn.py | 93 + .../local_libs/wenet/wenet/utils/common.py | 377 +++ .../local_libs/wenet/wenet/utils/config.py | 39 + .../wenet/wenet/utils/context_graph.py | 265 ++ .../local_libs/wenet/wenet/utils/ctc_utils.py | 136 + .../local_libs/wenet/wenet/utils/executor.py | 161 + .../wenet/wenet/utils/file_utils.py | 69 + .../wenet/wenet/utils/fsdp_utils.py | 116 + .../wenet/wenet/utils/init_dataset.py | 42 + .../wenet/wenet/utils/init_model.py | 217 ++ .../wenet/wenet/utils/init_tokenizer.py | 58 + .../local_libs/wenet/wenet/utils/mask.py | 373 +++ .../wenet/wenet/utils/rope_utils.py | 39 + .../local_libs/wenet/wenet/utils/scheduler.py | 722 ++++ .../wenet/wenet/utils/train_utils.py | 930 ++++++ .../audio_preprocessor/src/utils/run_wenet.py | 38 + .../mapper/audio_asr_transcribe/audio_skip.py | 114 + .../mapper/audio_asr_transcribe/metadata.yml | 108 + .../mapper/audio_asr_transcribe/process.py | 488 +++ .../audio_asr_transcribe/requirements.txt | 6 + .../mapper/audio_dc_offset_removal/README.md | 24 + .../audio_dc_offset_removal/__init__.py | 6 + .../audio_dc_offset_removal/audio_skip.py | 114 + .../audio_dc_offset_removal/metadata.yml | 26 + .../mapper/audio_dc_offset_removal/process.py | 97 + .../audio_dc_offset_removal/requirements.txt | 2 + .../mapper/audio_emotion_recognize/README.md | 34 + .../audio_emotion_recognize/__init__.py | 6 + .../audio_emotion_recognize/audio_skip.py | 119 + .../helpers/utils/emotion_small_model.py | 139 + .../audio_emotion_recognize/metadata.yml | 71 + .../mapper/audio_emotion_recognize/process.py | 345 ++ .../audio_emotion_recognize/requirements.txt | 8 + .../ops/mapper/audio_fast_lang_id/README.md | 40 + .../ops/mapper/audio_fast_lang_id/__init__.py | 6 + .../mapper/audio_fast_lang_id/audio_skip.py | 114 + .../helpers/utils/color_utils.py | 67 + .../helpers/utils/compute_wer.py | 553 ++++ .../helpers/utils/fast_lang_id.py | 487 +++ .../helpers/utils/generate_audio_list.py | 261 ++ .../helpers/utils/gtcrn_denoise.py | 349 ++ .../helpers/utils/yaml_config_loader.py | 119 + .../speechbrain/speechbrain/__init__.py | 71 + .../speechbrain/alignment/__init__.py | 1 + .../speechbrain/alignment/aligner.py | 1494 +++++++++ .../speechbrain/alignment/ctc_segmentation.py | 11 + .../speechbrain/augment/__init__.py | 1 + .../speechbrain/augment/augmenter.py | 544 +++ .../speechbrain/speechbrain/augment/codec.py | 92 + .../speechbrain/augment/freq_domain.py | 399 +++ .../speechbrain/augment/preparation.py | 219 ++ .../speechbrain/augment/time_domain.py | 1540 +++++++++ .../speechbrain/speechbrain/core.py | 1489 +++++++++ .../speechbrain/dataio/__init__.py | 5 + .../speechbrain/dataio/audio_io.py | 228 ++ .../speechbrain/speechbrain/dataio/batch.py | 333 ++ .../speechbrain/speechbrain/dataio/dataio.py | 1417 ++++++++ .../speechbrain/dataio/dataloader.py | 420 +++ .../speechbrain/speechbrain/dataio/dataset.py | 546 +++ .../speechbrain/speechbrain/dataio/encoder.py | 1216 +++++++ .../speechbrain/dataio/iterators.py | 235 ++ .../speechbrain/speechbrain/dataio/legacy.py | 321 ++ .../speechbrain/dataio/preprocess.py | 82 + .../speechbrain/speechbrain/dataio/sampler.py | 845 +++++ .../speechbrain/speechbrain/dataio/wer.py | 201 ++ .../speechbrain/decoders/__init__.py | 6 + .../speechbrain/speechbrain/decoders/ctc.py | 1905 +++++++++++ .../speechbrain/decoders/language_model.py | 11 + .../speechbrain/decoders/scorer.py | 2189 ++++++++++++ .../speechbrain/decoders/seq2seq.py | 2240 +++++++++++++ .../speechbrain/decoders/transducer.py | 648 ++++ .../speechbrain/speechbrain/decoders/utils.py | 158 + .../speechbrain/speechbrain/inference/ASR.py | 1546 +++++++++ .../speechbrain/speechbrain/inference/SLU.py | 144 + .../speechbrain/speechbrain/inference/ST.py | 138 + .../speechbrain/speechbrain/inference/TTS.py | 928 ++++++ .../speechbrain/speechbrain/inference/VAD.py | 965 ++++++ .../speechbrain/inference/__init__.py | 17 + .../speechbrain/inference/classifiers.py | 322 ++ .../speechbrain/inference/diarization.py | 241 ++ .../speechbrain/inference/encoders.py | 272 ++ .../speechbrain/inference/enhancement.py | 373 +++ .../speechbrain/inference/interfaces.py | 694 ++++ .../speechbrain/inference/interpretability.py | 182 + .../speechbrain/inference/metrics.py | 97 + .../speechbrain/inference/separation.py | 129 + .../speechbrain/inference/speaker.py | 133 + .../speechbrain/speechbrain/inference/text.py | 443 +++ .../speechbrain/inference/vocoders.py | 399 +++ .../speechbrain/integrations/README.md | 33 + .../speechbrain/integrations/__init__.py | 7 + .../integrations/alignment/README.md | 31 + .../integrations/alignment/__init__.py | 3 + .../integrations/alignment/ctc_seg.py | 675 ++++ .../integrations/alignment/diarization.py | 1231 +++++++ .../integrations/audio_tokenizers/README.md | 45 + .../integrations/audio_tokenizers/__init__.py | 3 + .../audio_tokenizers/discrete_ssl.py | 408 +++ .../integrations/audio_tokenizers/kmeans.py | 178 + .../speechtokenizer_interface.py | 157 + .../wavtokenizer_interface.py | 168 + .../integrations/decoders/README.md | 30 + .../integrations/decoders/__init__.py | 3 + .../integrations/decoders/kenlm_scorer.py | 321 ++ .../speechbrain/integrations/hdf5/README.md | 30 + .../speechbrain/integrations/hdf5/__init__.py | 7 + .../integrations/hdf5/cached_item.py | 159 + .../integrations/huggingface/README.md | 70 + .../integrations/huggingface/__init__.py | 20 + .../integrations/huggingface/encodec.py | 385 +++ .../integrations/huggingface/gpt.py | 179 + .../integrations/huggingface/hubert.py | 88 + .../integrations/huggingface/huggingface.py | 455 +++ .../integrations/huggingface/labse.py | 116 + .../integrations/huggingface/llama.py | 198 ++ .../integrations/huggingface/mbart.py | 221 ++ .../integrations/huggingface/mert.py | 88 + .../integrations/huggingface/mimi.py | 191 ++ .../integrations/huggingface/nllb.py | 75 + .../integrations/huggingface/textencoder.py | 122 + .../integrations/huggingface/vocos.py | 158 + .../integrations/huggingface/w2v_bert.py | 200 ++ .../integrations/huggingface/wav2vec2.py | 332 ++ .../integrations/huggingface/wavlm.py | 88 + .../integrations/huggingface/weighted_ssl.py | 122 + .../integrations/huggingface/whisper.py | 637 ++++ .../huggingface/wordemb/__init__.py | 1 + .../huggingface/wordemb/transformer.py | 289 ++ .../integrations/huggingface/wordemb/util.py | 72 + .../speechbrain/integrations/k2_fsa/README.md | 38 + .../integrations/k2_fsa/__init__.py | 20 + .../speechbrain/integrations/k2_fsa/align.py | 667 ++++ .../integrations/k2_fsa/graph_compiler.py | 387 +++ .../integrations/k2_fsa/lattice_decoder.py | 453 +++ .../integrations/k2_fsa/lexicon.py | 584 ++++ .../speechbrain/integrations/k2_fsa/losses.py | 134 + .../integrations/k2_fsa/prepare_lang.py | 575 ++++ .../speechbrain/integrations/k2_fsa/utils.py | 168 + .../speechbrain/integrations/models/README.md | 28 + .../integrations/models/__init__.py | 3 + .../integrations/models/sgmse_plus.py | 615 ++++ .../speechbrain/integrations/nlp/README.md | 36 + .../speechbrain/integrations/nlp/__init__.py | 5 + .../integrations/nlp/bgeM3_embeddings.py | 180 + .../speechbrain/integrations/nlp/bleu.py | 105 + .../integrations/nlp/flair_embeddings.py | 150 + .../integrations/nlp/flair_tagger.py | 87 + .../integrations/nlp/spacy_pipeline.py | 144 + .../speechbrain/integrations/numba/README.md | 25 + .../integrations/numba/__init__.py | 18 + .../integrations/numba/transducer_loss.py | 354 ++ .../integrations/tests/test_cached_item.py | 506 +++ .../tests/test_ctc_segmentation.py | 85 + .../speechbrain/integrations/tests/test_k2.py | 458 +++ .../integrations/tests/test_nlp.py | 78 + .../speechbrain/speechbrain/lm/__init__.py | 1 + .../speechbrain/speechbrain/lm/arpa.py | 353 ++ .../speechbrain/speechbrain/lm/counting.py | 166 + .../speechbrain/speechbrain/lm/ngram.py | 210 ++ .../speechbrain/speechbrain/lobes/__init__.py | 9 + .../speechbrain/lobes/beamform_multimic.py | 50 + .../speechbrain/lobes/downsampling.py | 176 + .../speechbrain/speechbrain/lobes/features.py | 862 +++++ .../speechbrain/lobes/models/BESTRQ.py | 128 + .../speechbrain/lobes/models/CRDNN.py | 315 ++ .../speechbrain/lobes/models/Cnn14.py | 422 +++ .../speechbrain/lobes/models/ContextNet.py | 304 ++ .../speechbrain/lobes/models/DiffWave.py | 701 ++++ .../speechbrain/lobes/models/ECAPA_TDNN.py | 636 ++++ .../speechbrain/lobes/models/ESPnetVGG.py | 128 + .../speechbrain/lobes/models/EnhanceResnet.py | 251 ++ .../speechbrain/lobes/models/FastSpeech2.py | 2924 +++++++++++++++++ .../speechbrain/lobes/models/GatedNN.py | 135 + .../speechbrain/lobes/models/HifiGAN.py | 1838 +++++++++++ .../speechbrain/lobes/models/L2I.py | 581 ++++ .../speechbrain/lobes/models/MSTacotron2.py | 754 +++++ .../speechbrain/lobes/models/MetricGAN.py | 195 ++ .../speechbrain/lobes/models/MetricGAN_U.py | 193 ++ .../speechbrain/lobes/models/PIQ.py | 699 ++++ .../speechbrain/lobes/models/RNNLM.py | 124 + .../speechbrain/lobes/models/ResNet.py | 520 +++ .../speechbrain/lobes/models/Tacotron2.py | 1886 +++++++++++ .../speechbrain/lobes/models/VanillaNN.py | 51 + .../speechbrain/lobes/models/Xvector.py | 246 ++ .../speechbrain/lobes/models/__init__.py | 1 + .../speechbrain/lobes/models/beats.py | 2096 ++++++++++++ .../speechbrain/lobes/models/bsq.py | 181 + .../speechbrain/lobes/models/conv_tasnet.py | 622 ++++ .../speechbrain/lobes/models/convolution.py | 320 ++ .../lobes/models/discrete/__init__.py | 6 + .../speechbrain/lobes/models/discrete/dac.py | 1122 +++++++ .../speechbrain/lobes/models/dual_path.py | 1494 +++++++++ .../lobes/models/fairseq_wav2vec.py | 362 ++ .../speechbrain/lobes/models/g2p/__init__.py | 5 + .../speechbrain/lobes/models/g2p/dataio.py | 688 ++++ .../speechbrain/lobes/models/g2p/homograph.py | 681 ++++ .../speechbrain/lobes/models/g2p/model.py | 582 ++++ .../speechbrain/lobes/models/kmeans.py | 11 + .../speechbrain/lobes/models/resepformer.py | 781 +++++ .../speechbrain/lobes/models/segan_model.py | 253 ++ .../lobes/models/transformer/Branchformer.py | 409 +++ .../lobes/models/transformer/Conformer.py | 1153 +++++++ .../lobes/models/transformer/Transformer.py | 1100 +++++++ .../models/transformer/TransformerASR.py | 726 ++++ .../lobes/models/transformer/TransformerLM.py | 187 ++ .../lobes/models/transformer/TransformerSE.py | 104 + .../lobes/models/transformer/TransformerST.py | 437 +++ .../lobes/models/transformer/__init__.py | 5 + .../speechbrain/lobes/models/wav2vec.py | 413 +++ .../speechbrain/speechbrain/log-config.yaml | 25 + .../speechbrain/speechbrain/nnet/CNN.py | 1571 +++++++++ .../speechbrain/speechbrain/nnet/RNN.py | 2171 ++++++++++++ .../speechbrain/speechbrain/nnet/__init__.py | 7 + .../speechbrain/nnet/activations.py | 171 + .../speechbrain/speechbrain/nnet/adapters.py | 389 +++ .../speechbrain/speechbrain/nnet/attention.py | 1440 ++++++++ .../speechbrain/nnet/autoencoders.py | 481 +++ .../nnet/complex_networks/__init__.py | 1 + .../nnet/complex_networks/c_CNN.py | 498 +++ .../nnet/complex_networks/c_RNN.py | 1295 ++++++++ .../nnet/complex_networks/c_linear.py | 124 + .../nnet/complex_networks/c_normalization.py | 745 +++++ .../nnet/complex_networks/c_ops.py | 355 ++ .../speechbrain/nnet/containers.py | 408 +++ .../speechbrain/speechbrain/nnet/diffusion.py | 676 ++++ .../speechbrain/speechbrain/nnet/dropout.py | 60 + .../speechbrain/speechbrain/nnet/embedding.py | 120 + .../speechbrain/nnet/hypermixing.py | 372 +++ .../speechbrain/speechbrain/nnet/linear.py | 91 + .../speechbrain/nnet/loss/__init__.py | 1 + .../speechbrain/nnet/loss/guidedattn_loss.py | 178 + .../speechbrain/nnet/loss/si_snr_loss.py | 66 + .../speechbrain/nnet/loss/stoi_loss.py | 226 ++ .../speechbrain/speechbrain/nnet/losses.py | 1990 +++++++++++ .../speechbrain/nnet/normalization.py | 668 ++++ .../speechbrain/speechbrain/nnet/pooling.py | 609 ++++ .../speechbrain/nnet/quantisers.py | 184 ++ .../nnet/quaternion_networks/__init__.py | 1 + .../nnet/quaternion_networks/q_CNN.py | 681 ++++ .../nnet/quaternion_networks/q_RNN.py | 1313 ++++++++ .../nnet/quaternion_networks/q_linear.py | 242 ++ .../quaternion_networks/q_normalization.py | 162 + .../nnet/quaternion_networks/q_ops.py | 886 +++++ .../nnet/quaternion_networks/q_pooling.py | 125 + .../speechbrain/nnet/schedulers.py | 1710 ++++++++++ .../speechbrain/nnet/transducer/__init__.py | 1 + .../nnet/transducer/transducer_joint.py | 102 + .../speechbrain/speechbrain/nnet/unet.py | 1842 +++++++++++ .../speechbrain/speechbrain/nnet/utils.py | 88 + .../speechbrain/speechbrain/processing/NMF.py | 198 ++ .../speechbrain/processing/PLDA_LDA.py | 1072 ++++++ .../speechbrain/processing/__init__.py | 1 + .../speechbrain/processing/decomposition.py | 441 +++ .../speechbrain/processing/diarization.py | 11 + .../speechbrain/processing/features.py | 1913 +++++++++++ .../speechbrain/processing/multi_mic.py | 1589 +++++++++ .../processing/signal_processing.py | 652 ++++ .../speechbrain/processing/vocal_features.py | 520 +++ .../speechbrain/tokenizers/SentencePiece.py | 575 ++++ .../speechbrain/tokenizers/__init__.py | 1 + .../tokenizers/discrete_SSL_tokenizer.py | 127 + .../speechbrain/speechbrain/utils/Accuracy.py | 103 + .../speechbrain/speechbrain/utils/DER.py | 152 + .../speechbrain/speechbrain/utils/EDER.py | 286 ++ .../speechbrain/speechbrain/utils/__init__.py | 7 + .../speechbrain/utils/_workarounds.py | 36 + .../speechbrain/speechbrain/utils/autocast.py | 252 ++ .../speechbrain/utils/bertscore.py | 351 ++ .../speechbrain/speechbrain/utils/bleu.py | 11 + .../speechbrain/utils/callchains.py | 85 + .../speechbrain/utils/checkpoints.py | 1384 ++++++++ .../speechbrain/utils/data_pipeline.py | 690 ++++ .../speechbrain/utils/data_utils.py | 1262 +++++++ .../speechbrain/speechbrain/utils/depgraph.py | 273 ++ .../speechbrain/utils/dictionaries.py | 122 + .../speechbrain/utils/distances.py | 50 + .../speechbrain/utils/distributed.py | 501 +++ .../utils/dynamic_chunk_training.py | 188 ++ .../speechbrain/utils/edit_distance.py | 797 +++++ .../speechbrain/utils/epoch_loop.py | 201 ++ .../speechbrain/speechbrain/utils/fetching.py | 436 +++ .../speechbrain/utils/filter_analysis.py | 226 ++ .../speechbrain/speechbrain/utils/hparams.py | 37 + .../speechbrain/speechbrain/utils/hpopt.py | 494 +++ .../speechbrain/utils/importutils.py | 309 ++ .../speechbrain/speechbrain/utils/kmeans.py | 229 ++ .../speechbrain/speechbrain/utils/logger.py | 320 ++ .../speechbrain/utils/metric_stats.py | 1425 ++++++++ .../speechbrain/utils/optimizers.py | 37 + .../speechbrain/speechbrain/utils/parallel.py | 346 ++ .../speechbrain/utils/parameter_transfer.py | 350 ++ .../speechbrain/utils/pretrained.py | 96 + .../speechbrain/utils/profiling.py | 40 + .../speechbrain/speechbrain/utils/quirks.py | 123 + .../speechbrain/speechbrain/utils/repro.py | 172 + .../speechbrain/speechbrain/utils/run_opts.py | 363 ++ .../speechbrain/speechbrain/utils/seed.py | 71 + .../speechbrain/speechbrain/utils/semdist.py | 197 ++ .../speechbrain/utils/streaming.py | 235 ++ .../speechbrain/utils/superpowers.py | 87 + .../speechbrain/utils/text_to_sequence.py | 388 +++ .../speechbrain/utils/torch_audio_backend.py | 107 + .../speechbrain/utils/train_logger.py | 484 +++ .../speechbrain/speechbrain/version.txt | 1 + .../mapper/audio_fast_lang_id/metadata.yml | 67 + .../ops/mapper/audio_fast_lang_id/process.py | 178 + .../audio_fast_lang_id/requirements.txt | 3 + .../mapper/audio_fast_lang_id_text/README.md | 38 + .../audio_fast_lang_id_text/__init__.py | 6 + .../audio_fast_lang_id_text/audio_skip.py | 114 + .../helpers/utils/color_utils.py | 67 + .../helpers/utils/compute_wer.py | 553 ++++ .../helpers/utils/fast_lang_id.py | 487 +++ .../helpers/utils/generate_audio_list.py | 261 ++ .../helpers/utils/gtcrn_denoise.py | 349 ++ .../helpers/utils/yaml_config_loader.py | 119 + .../speechbrain/speechbrain/__init__.py | 71 + .../speechbrain/alignment/__init__.py | 1 + .../speechbrain/alignment/aligner.py | 1494 +++++++++ .../speechbrain/alignment/ctc_segmentation.py | 11 + .../speechbrain/augment/__init__.py | 1 + .../speechbrain/augment/augmenter.py | 544 +++ .../speechbrain/speechbrain/augment/codec.py | 92 + .../speechbrain/augment/freq_domain.py | 399 +++ .../speechbrain/augment/preparation.py | 219 ++ .../speechbrain/augment/time_domain.py | 1540 +++++++++ .../speechbrain/speechbrain/core.py | 1489 +++++++++ .../speechbrain/dataio/__init__.py | 5 + .../speechbrain/dataio/audio_io.py | 228 ++ .../speechbrain/speechbrain/dataio/batch.py | 333 ++ .../speechbrain/speechbrain/dataio/dataio.py | 1417 ++++++++ .../speechbrain/dataio/dataloader.py | 420 +++ .../speechbrain/speechbrain/dataio/dataset.py | 546 +++ .../speechbrain/speechbrain/dataio/encoder.py | 1216 +++++++ .../speechbrain/dataio/iterators.py | 235 ++ .../speechbrain/speechbrain/dataio/legacy.py | 321 ++ .../speechbrain/dataio/preprocess.py | 82 + .../speechbrain/speechbrain/dataio/sampler.py | 845 +++++ .../speechbrain/speechbrain/dataio/wer.py | 201 ++ .../speechbrain/decoders/__init__.py | 6 + .../speechbrain/speechbrain/decoders/ctc.py | 1905 +++++++++++ .../speechbrain/decoders/language_model.py | 11 + .../speechbrain/decoders/scorer.py | 2189 ++++++++++++ .../speechbrain/decoders/seq2seq.py | 2240 +++++++++++++ .../speechbrain/decoders/transducer.py | 648 ++++ .../speechbrain/speechbrain/decoders/utils.py | 158 + .../speechbrain/speechbrain/inference/ASR.py | 1546 +++++++++ .../speechbrain/speechbrain/inference/SLU.py | 144 + .../speechbrain/speechbrain/inference/ST.py | 138 + .../speechbrain/speechbrain/inference/TTS.py | 928 ++++++ .../speechbrain/speechbrain/inference/VAD.py | 965 ++++++ .../speechbrain/inference/__init__.py | 17 + .../speechbrain/inference/classifiers.py | 322 ++ .../speechbrain/inference/diarization.py | 241 ++ .../speechbrain/inference/encoders.py | 272 ++ .../speechbrain/inference/enhancement.py | 373 +++ .../speechbrain/inference/interfaces.py | 694 ++++ .../speechbrain/inference/interpretability.py | 182 + .../speechbrain/inference/metrics.py | 97 + .../speechbrain/inference/separation.py | 129 + .../speechbrain/inference/speaker.py | 133 + .../speechbrain/speechbrain/inference/text.py | 443 +++ .../speechbrain/inference/vocoders.py | 399 +++ .../speechbrain/integrations/README.md | 33 + .../speechbrain/integrations/__init__.py | 7 + .../integrations/alignment/README.md | 31 + .../integrations/alignment/__init__.py | 3 + .../integrations/alignment/ctc_seg.py | 675 ++++ .../integrations/alignment/diarization.py | 1231 +++++++ .../integrations/audio_tokenizers/README.md | 45 + .../integrations/audio_tokenizers/__init__.py | 3 + .../audio_tokenizers/discrete_ssl.py | 408 +++ .../integrations/audio_tokenizers/kmeans.py | 178 + .../speechtokenizer_interface.py | 157 + .../wavtokenizer_interface.py | 168 + .../integrations/decoders/README.md | 30 + .../integrations/decoders/__init__.py | 3 + .../integrations/decoders/kenlm_scorer.py | 321 ++ .../speechbrain/integrations/hdf5/README.md | 30 + .../speechbrain/integrations/hdf5/__init__.py | 7 + .../integrations/hdf5/cached_item.py | 159 + .../integrations/huggingface/README.md | 70 + .../integrations/huggingface/__init__.py | 20 + .../integrations/huggingface/encodec.py | 385 +++ .../integrations/huggingface/gpt.py | 179 + .../integrations/huggingface/hubert.py | 88 + .../integrations/huggingface/huggingface.py | 455 +++ .../integrations/huggingface/labse.py | 116 + .../integrations/huggingface/llama.py | 198 ++ .../integrations/huggingface/mbart.py | 221 ++ .../integrations/huggingface/mert.py | 88 + .../integrations/huggingface/mimi.py | 191 ++ .../integrations/huggingface/nllb.py | 75 + .../integrations/huggingface/textencoder.py | 122 + .../integrations/huggingface/vocos.py | 158 + .../integrations/huggingface/w2v_bert.py | 200 ++ .../integrations/huggingface/wav2vec2.py | 332 ++ .../integrations/huggingface/wavlm.py | 88 + .../integrations/huggingface/weighted_ssl.py | 122 + .../integrations/huggingface/whisper.py | 637 ++++ .../huggingface/wordemb/__init__.py | 1 + .../huggingface/wordemb/transformer.py | 289 ++ .../integrations/huggingface/wordemb/util.py | 72 + .../speechbrain/integrations/k2_fsa/README.md | 38 + .../integrations/k2_fsa/__init__.py | 20 + .../speechbrain/integrations/k2_fsa/align.py | 667 ++++ .../integrations/k2_fsa/graph_compiler.py | 387 +++ .../integrations/k2_fsa/lattice_decoder.py | 453 +++ .../integrations/k2_fsa/lexicon.py | 584 ++++ .../speechbrain/integrations/k2_fsa/losses.py | 134 + .../integrations/k2_fsa/prepare_lang.py | 575 ++++ .../speechbrain/integrations/k2_fsa/utils.py | 168 + .../speechbrain/integrations/models/README.md | 28 + .../integrations/models/__init__.py | 3 + .../integrations/models/sgmse_plus.py | 615 ++++ .../speechbrain/integrations/nlp/README.md | 36 + .../speechbrain/integrations/nlp/__init__.py | 5 + .../integrations/nlp/bgeM3_embeddings.py | 180 + .../speechbrain/integrations/nlp/bleu.py | 105 + .../integrations/nlp/flair_embeddings.py | 150 + .../integrations/nlp/flair_tagger.py | 87 + .../integrations/nlp/spacy_pipeline.py | 144 + .../speechbrain/integrations/numba/README.md | 25 + .../integrations/numba/__init__.py | 18 + .../integrations/numba/transducer_loss.py | 354 ++ .../integrations/tests/test_cached_item.py | 506 +++ .../tests/test_ctc_segmentation.py | 85 + .../speechbrain/integrations/tests/test_k2.py | 458 +++ .../integrations/tests/test_nlp.py | 78 + .../speechbrain/speechbrain/lm/__init__.py | 1 + .../speechbrain/speechbrain/lm/arpa.py | 353 ++ .../speechbrain/speechbrain/lm/counting.py | 166 + .../speechbrain/speechbrain/lm/ngram.py | 210 ++ .../speechbrain/speechbrain/lobes/__init__.py | 9 + .../speechbrain/lobes/beamform_multimic.py | 50 + .../speechbrain/lobes/downsampling.py | 176 + .../speechbrain/speechbrain/lobes/features.py | 862 +++++ .../speechbrain/lobes/models/BESTRQ.py | 128 + .../speechbrain/lobes/models/CRDNN.py | 315 ++ .../speechbrain/lobes/models/Cnn14.py | 422 +++ .../speechbrain/lobes/models/ContextNet.py | 304 ++ .../speechbrain/lobes/models/DiffWave.py | 701 ++++ .../speechbrain/lobes/models/ECAPA_TDNN.py | 636 ++++ .../speechbrain/lobes/models/ESPnetVGG.py | 128 + .../speechbrain/lobes/models/EnhanceResnet.py | 251 ++ .../speechbrain/lobes/models/FastSpeech2.py | 2924 +++++++++++++++++ .../speechbrain/lobes/models/GatedNN.py | 135 + .../speechbrain/lobes/models/HifiGAN.py | 1838 +++++++++++ .../speechbrain/lobes/models/L2I.py | 581 ++++ .../speechbrain/lobes/models/MSTacotron2.py | 754 +++++ .../speechbrain/lobes/models/MetricGAN.py | 195 ++ .../speechbrain/lobes/models/MetricGAN_U.py | 193 ++ .../speechbrain/lobes/models/PIQ.py | 699 ++++ .../speechbrain/lobes/models/RNNLM.py | 124 + .../speechbrain/lobes/models/ResNet.py | 520 +++ .../speechbrain/lobes/models/Tacotron2.py | 1886 +++++++++++ .../speechbrain/lobes/models/VanillaNN.py | 51 + .../speechbrain/lobes/models/Xvector.py | 246 ++ .../speechbrain/lobes/models/__init__.py | 1 + .../speechbrain/lobes/models/beats.py | 2096 ++++++++++++ .../speechbrain/lobes/models/bsq.py | 181 + .../speechbrain/lobes/models/conv_tasnet.py | 622 ++++ .../speechbrain/lobes/models/convolution.py | 320 ++ .../lobes/models/discrete/__init__.py | 6 + .../speechbrain/lobes/models/discrete/dac.py | 1122 +++++++ .../speechbrain/lobes/models/dual_path.py | 1494 +++++++++ .../lobes/models/fairseq_wav2vec.py | 362 ++ .../speechbrain/lobes/models/g2p/__init__.py | 5 + .../speechbrain/lobes/models/g2p/dataio.py | 688 ++++ .../speechbrain/lobes/models/g2p/homograph.py | 681 ++++ .../speechbrain/lobes/models/g2p/model.py | 582 ++++ .../speechbrain/lobes/models/kmeans.py | 11 + .../speechbrain/lobes/models/resepformer.py | 781 +++++ .../speechbrain/lobes/models/segan_model.py | 253 ++ .../lobes/models/transformer/Branchformer.py | 409 +++ .../lobes/models/transformer/Conformer.py | 1153 +++++++ .../lobes/models/transformer/Transformer.py | 1100 +++++++ .../models/transformer/TransformerASR.py | 726 ++++ .../lobes/models/transformer/TransformerLM.py | 187 ++ .../lobes/models/transformer/TransformerSE.py | 104 + .../lobes/models/transformer/TransformerST.py | 437 +++ .../lobes/models/transformer/__init__.py | 5 + .../speechbrain/lobes/models/wav2vec.py | 413 +++ .../speechbrain/speechbrain/log-config.yaml | 25 + .../speechbrain/speechbrain/nnet/CNN.py | 1571 +++++++++ .../speechbrain/speechbrain/nnet/RNN.py | 2171 ++++++++++++ .../speechbrain/speechbrain/nnet/__init__.py | 7 + .../speechbrain/nnet/activations.py | 171 + .../speechbrain/speechbrain/nnet/adapters.py | 389 +++ .../speechbrain/speechbrain/nnet/attention.py | 1440 ++++++++ .../speechbrain/nnet/autoencoders.py | 481 +++ .../nnet/complex_networks/__init__.py | 1 + .../nnet/complex_networks/c_CNN.py | 498 +++ .../nnet/complex_networks/c_RNN.py | 1295 ++++++++ .../nnet/complex_networks/c_linear.py | 124 + .../nnet/complex_networks/c_normalization.py | 745 +++++ .../nnet/complex_networks/c_ops.py | 355 ++ .../speechbrain/nnet/containers.py | 408 +++ .../speechbrain/speechbrain/nnet/diffusion.py | 676 ++++ .../speechbrain/speechbrain/nnet/dropout.py | 60 + .../speechbrain/speechbrain/nnet/embedding.py | 120 + .../speechbrain/nnet/hypermixing.py | 372 +++ .../speechbrain/speechbrain/nnet/linear.py | 91 + .../speechbrain/nnet/loss/__init__.py | 1 + .../speechbrain/nnet/loss/guidedattn_loss.py | 178 + .../speechbrain/nnet/loss/si_snr_loss.py | 66 + .../speechbrain/nnet/loss/stoi_loss.py | 226 ++ .../speechbrain/speechbrain/nnet/losses.py | 1990 +++++++++++ .../speechbrain/nnet/normalization.py | 668 ++++ .../speechbrain/speechbrain/nnet/pooling.py | 609 ++++ .../speechbrain/nnet/quantisers.py | 184 ++ .../nnet/quaternion_networks/__init__.py | 1 + .../nnet/quaternion_networks/q_CNN.py | 681 ++++ .../nnet/quaternion_networks/q_RNN.py | 1313 ++++++++ .../nnet/quaternion_networks/q_linear.py | 242 ++ .../quaternion_networks/q_normalization.py | 162 + .../nnet/quaternion_networks/q_ops.py | 886 +++++ .../nnet/quaternion_networks/q_pooling.py | 125 + .../speechbrain/nnet/schedulers.py | 1710 ++++++++++ .../speechbrain/nnet/transducer/__init__.py | 1 + .../nnet/transducer/transducer_joint.py | 102 + .../speechbrain/speechbrain/nnet/unet.py | 1842 +++++++++++ .../speechbrain/speechbrain/nnet/utils.py | 88 + .../speechbrain/speechbrain/processing/NMF.py | 198 ++ .../speechbrain/processing/PLDA_LDA.py | 1072 ++++++ .../speechbrain/processing/__init__.py | 1 + .../speechbrain/processing/decomposition.py | 441 +++ .../speechbrain/processing/diarization.py | 11 + .../speechbrain/processing/features.py | 1913 +++++++++++ .../speechbrain/processing/multi_mic.py | 1589 +++++++++ .../processing/signal_processing.py | 652 ++++ .../speechbrain/processing/vocal_features.py | 520 +++ .../speechbrain/tokenizers/SentencePiece.py | 575 ++++ .../speechbrain/tokenizers/__init__.py | 1 + .../tokenizers/discrete_SSL_tokenizer.py | 127 + .../speechbrain/speechbrain/utils/Accuracy.py | 103 + .../speechbrain/speechbrain/utils/DER.py | 152 + .../speechbrain/speechbrain/utils/EDER.py | 286 ++ .../speechbrain/speechbrain/utils/__init__.py | 7 + .../speechbrain/utils/_workarounds.py | 36 + .../speechbrain/speechbrain/utils/autocast.py | 252 ++ .../speechbrain/utils/bertscore.py | 351 ++ .../speechbrain/speechbrain/utils/bleu.py | 11 + .../speechbrain/utils/callchains.py | 85 + .../speechbrain/utils/checkpoints.py | 1384 ++++++++ .../speechbrain/utils/data_pipeline.py | 690 ++++ .../speechbrain/utils/data_utils.py | 1262 +++++++ .../speechbrain/speechbrain/utils/depgraph.py | 273 ++ .../speechbrain/utils/dictionaries.py | 122 + .../speechbrain/utils/distances.py | 50 + .../speechbrain/utils/distributed.py | 501 +++ .../utils/dynamic_chunk_training.py | 188 ++ .../speechbrain/utils/edit_distance.py | 797 +++++ .../speechbrain/utils/epoch_loop.py | 201 ++ .../speechbrain/speechbrain/utils/fetching.py | 436 +++ .../speechbrain/utils/filter_analysis.py | 226 ++ .../speechbrain/speechbrain/utils/hparams.py | 37 + .../speechbrain/speechbrain/utils/hpopt.py | 494 +++ .../speechbrain/utils/importutils.py | 309 ++ .../speechbrain/speechbrain/utils/kmeans.py | 229 ++ .../speechbrain/speechbrain/utils/logger.py | 320 ++ .../speechbrain/utils/metric_stats.py | 1425 ++++++++ .../speechbrain/utils/optimizers.py | 37 + .../speechbrain/speechbrain/utils/parallel.py | 346 ++ .../speechbrain/utils/parameter_transfer.py | 350 ++ .../speechbrain/utils/pretrained.py | 96 + .../speechbrain/utils/profiling.py | 40 + .../speechbrain/speechbrain/utils/quirks.py | 123 + .../speechbrain/speechbrain/utils/repro.py | 172 + .../speechbrain/speechbrain/utils/run_opts.py | 363 ++ .../speechbrain/speechbrain/utils/seed.py | 71 + .../speechbrain/speechbrain/utils/semdist.py | 197 ++ .../speechbrain/utils/streaming.py | 235 ++ .../speechbrain/utils/superpowers.py | 87 + .../speechbrain/utils/text_to_sequence.py | 388 +++ .../speechbrain/utils/torch_audio_backend.py | 107 + .../speechbrain/utils/train_logger.py | 484 +++ .../speechbrain/speechbrain/version.txt | 1 + .../audio_fast_lang_id_text/metadata.yml | 67 + .../mapper/audio_fast_lang_id_text/process.py | 157 + .../audio_fast_lang_id_text/requirements.txt | 3 + .../ops/mapper/audio_format_convert/README.md | 26 + .../mapper/audio_format_convert/__init__.py | 6 + .../mapper/audio_format_convert/audio_skip.py | 114 + .../mapper/audio_format_convert/metadata.yml | 61 + .../mapper/audio_format_convert/process.py | 214 ++ .../audio_format_convert/requirements.txt | 3 + .../ops/mapper/audio_gtcrn_denoise/README.md | 24 + .../mapper/audio_gtcrn_denoise/__init__.py | 6 + .../mapper/audio_gtcrn_denoise/audio_skip.py | 114 + .../helpers/utils/color_utils.py | 92 + .../helpers/utils/gtcrn_denoise.py | 349 ++ .../mapper/audio_gtcrn_denoise/metadata.yml | 32 + .../ops/mapper/audio_gtcrn_denoise/process.py | 97 + .../audio_gtcrn_denoise/requirements.txt | 4 + runtime/ops/mapper/audio_hum_notch/README.md | 25 + .../ops/mapper/audio_hum_notch/__init__.py | 6 + .../ops/mapper/audio_hum_notch/audio_skip.py | 114 + .../ops/mapper/audio_hum_notch/metadata.yml | 45 + runtime/ops/mapper/audio_hum_notch/process.py | 105 + .../mapper/audio_hum_notch/requirements.txt | 3 + runtime/ops/mapper/audio_noise_gate/README.md | 27 + .../ops/mapper/audio_noise_gate/__init__.py | 6 + .../ops/mapper/audio_noise_gate/audio_skip.py | 114 + .../ops/mapper/audio_noise_gate/metadata.yml | 58 + .../ops/mapper/audio_noise_gate/process.py | 112 + .../mapper/audio_noise_gate/requirements.txt | 2 + .../ops/mapper/audio_pre_emphasis/README.md | 24 + .../ops/mapper/audio_pre_emphasis/__init__.py | 6 + .../mapper/audio_pre_emphasis/audio_skip.py | 114 + .../mapper/audio_pre_emphasis/metadata.yml | 34 + .../ops/mapper/audio_pre_emphasis/process.py | 101 + .../audio_pre_emphasis/requirements.txt | 2 + .../mapper/audio_quantize_encode/README.md | 26 + .../mapper/audio_quantize_encode/__init__.py | 6 + .../audio_quantize_encode/audio_skip.py | 114 + .../mapper/audio_quantize_encode/metadata.yml | 57 + .../mapper/audio_quantize_encode/process.py | 130 + .../audio_quantize_encode/requirements.txt | 2 + .../audio_rms_loudness_normalize/README.md | 25 + .../audio_rms_loudness_normalize/__init__.py | 6 + .../audio_skip.py | 114 + .../audio_rms_loudness_normalize/metadata.yml | 42 + .../audio_rms_loudness_normalize/process.py | 110 + .../requirements.txt | 2 + runtime/ops/mapper/audio_simple_agc/README.md | 27 + .../ops/mapper/audio_simple_agc/__init__.py | 6 + .../ops/mapper/audio_simple_agc/audio_skip.py | 114 + .../ops/mapper/audio_simple_agc/metadata.yml | 58 + .../ops/mapper/audio_simple_agc/process.py | 114 + .../mapper/audio_simple_agc/requirements.txt | 2 + .../mapper/audio_soft_peak_limiter/README.md | 25 + .../audio_soft_peak_limiter/__init__.py | 6 + .../audio_soft_peak_limiter/audio_skip.py | 114 + .../audio_soft_peak_limiter/metadata.yml | 42 + .../mapper/audio_soft_peak_limiter/process.py | 112 + .../audio_soft_peak_limiter/requirements.txt | 2 + .../ops/mapper/audio_sound_classify/README.md | 34 + .../mapper/audio_sound_classify/__init__.py | 6 + .../mapper/audio_sound_classify/audio_skip.py | 119 + .../local_libs/ast_vendor/__init__.py | 2 + .../local_libs/ast_vendor/ast_models.py | 293 ++ .../metadata/class_labels_indices.csv | 528 +++ .../local_libs/panns_inference/LICENSE.MIT | 21 + .../panns_inference/__init__.py | 4 + .../panns_inference/panns_inference/config.py | 42 + .../panns_inference/inference.py | 170 + .../panns_inference/panns_inference/models.py | 276 ++ .../panns_inference/pytorch_utils.py | 92 + .../mapper/audio_sound_classify/metadata.yml | 138 + .../models/panns/classes_macro_draft.tsv | 528 +++ .../models/recog/audioset_macro_map_v1.json | 133 + .../mapper/audio_sound_classify/process.py | 566 ++++ .../audio_sound_classify/requirements.txt | 8 + .../mapper/audio_telephony_bandpass/README.md | 26 + .../audio_telephony_bandpass/__init__.py | 6 + .../audio_telephony_bandpass/audio_skip.py | 114 + .../audio_telephony_bandpass/metadata.yml | 50 + .../audio_telephony_bandpass/process.py | 113 + .../audio_telephony_bandpass/requirements.txt | 3 + .../ops/mapper/audio_text_summarize/README.md | 37 + .../mapper/audio_text_summarize/__init__.py | 6 + .../mapper/audio_text_summarize/metadata.yml | 119 + .../mapper/audio_text_summarize/process.py | 471 +++ .../audio_text_summarize/requirements.txt | 5 + .../mapper/audio_trim_silence_edges/README.md | 27 + .../audio_trim_silence_edges/__init__.py | 6 + .../audio_trim_silence_edges/audio_skip.py | 114 + .../audio_trim_silence_edges/metadata.yml | 58 + .../audio_trim_silence_edges/process.py | 126 + .../audio_trim_silence_edges/requirements.txt | 2 + 1236 files changed, 401723 insertions(+), 1 deletion(-) create mode 100644 runtime/ops/mapper/audio_anomaly_filter/README.md create mode 100644 runtime/ops/mapper/audio_anomaly_filter/__init__.py create mode 100644 runtime/ops/mapper/audio_anomaly_filter/audio_skip.py create mode 100644 runtime/ops/mapper/audio_anomaly_filter/metadata.yml create mode 100644 runtime/ops/mapper/audio_anomaly_filter/process.py create mode 100644 runtime/ops/mapper/audio_anomaly_filter/requirements.txt create mode 100644 runtime/ops/mapper/audio_asr_pipeline/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/audio_config.yaml create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/eval_wer.yaml create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/merge_asr_by_source.yaml create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/aligner.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/ctc_segmentation.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/augmenter.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/codec.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/freq_domain.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/preparation.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/time_domain.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/core.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/audio_io.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/batch.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataio.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataloader.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataset.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/iterators.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/legacy.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/preprocess.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/sampler.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/wer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/ctc.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/language_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/scorer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/seq2seq.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/transducer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ASR.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/SLU.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ST.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/TTS.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/VAD.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/classifiers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/diarization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/encoders.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/enhancement.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interfaces.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interpretability.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/metrics.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/separation.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/speaker.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/text.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/vocoders.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/ctc_seg.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/diarization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/discrete_ssl.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/kmeans.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/speechtokenizer_interface.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/wavtokenizer_interface.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/decoders/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/decoders/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/decoders/kenlm_scorer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/hdf5/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/hdf5/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/hdf5/cached_item.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/encodec.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/gpt.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/hubert.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/huggingface.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/labse.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/llama.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/mbart.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/mert.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/mimi.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/nllb.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/textencoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/vocos.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/w2v_bert.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/wav2vec2.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/wavlm.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/weighted_ssl.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/whisper.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/transformer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/util.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/align.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/graph_compiler.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/lattice_decoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/lexicon.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/losses.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/prepare_lang.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/k2_fsa/utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/models/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/models/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/models/sgmse_plus.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/nlp/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/nlp/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/nlp/bgeM3_embeddings.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/nlp/bleu.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/nlp/flair_embeddings.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/nlp/flair_tagger.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/nlp/spacy_pipeline.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/numba/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/numba/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/numba/transducer_loss.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/tests/test_cached_item.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/tests/test_ctc_segmentation.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/tests/test_k2.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/tests/test_nlp.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lm/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lm/arpa.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lm/counting.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lm/ngram.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/beamform_multimic.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/downsampling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/features.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/BESTRQ.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/CRDNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/Cnn14.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/ContextNet.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/DiffWave.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/ECAPA_TDNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/ESPnetVGG.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/EnhanceResnet.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/FastSpeech2.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/GatedNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/HifiGAN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/L2I.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/MSTacotron2.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/MetricGAN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/MetricGAN_U.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/PIQ.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/RNNLM.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/ResNet.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/Tacotron2.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/VanillaNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/Xvector.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/beats.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/bsq.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/conv_tasnet.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/convolution.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/discrete/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/discrete/dac.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/dual_path.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/fairseq_wav2vec.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/g2p/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/g2p/dataio.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/g2p/homograph.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/g2p/model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/kmeans.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/resepformer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/segan_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/Branchformer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/Conformer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/Transformer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerASR.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerLM.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerSE.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerST.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/transformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/lobes/models/wav2vec.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/log-config.yaml create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/CNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/RNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/activations.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/adapters.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/attention.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/autoencoders.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/complex_networks/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_CNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_RNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_linear.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_normalization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_ops.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/containers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/diffusion.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/dropout.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/embedding.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/hypermixing.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/linear.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/loss/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/loss/guidedattn_loss.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/loss/si_snr_loss.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/loss/stoi_loss.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/losses.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/normalization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/pooling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quantisers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_CNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_RNN.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_linear.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_normalization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_ops.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_pooling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/schedulers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/transducer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/transducer/transducer_joint.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/unet.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/nnet/utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/NMF.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/PLDA_LDA.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/decomposition.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/diarization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/features.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/multi_mic.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/signal_processing.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/processing/vocal_features.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/tokenizers/SentencePiece.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/tokenizers/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/tokenizers/discrete_SSL_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/Accuracy.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/DER.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/EDER.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/_workarounds.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/autocast.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/bertscore.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/bleu.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/callchains.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/checkpoints.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/data_pipeline.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/data_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/depgraph.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/dictionaries.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/distances.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/distributed.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/dynamic_chunk_training.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/edit_distance.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/epoch_loop.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/fetching.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/filter_analysis.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/hparams.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/hpopt.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/importutils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/kmeans.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/logger.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/metric_stats.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/optimizers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/parallel.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/parameter_transfer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/pretrained.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/profiling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/quirks.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/repro.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/run_opts.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/seed.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/semdist.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/streaming.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/superpowers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/text_to_sequence.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/torch_audio_backend.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/utils/train_logger.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/version.txt create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/README.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/cli/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/cli/hub.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/cli/model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/cli/punc_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/cli/transcribe.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/datapipes.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/dataset.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/deprecated/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/deprecated/dataset.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/deprecated/processor.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/kaldi_io.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/processor.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/dataset/wav_distortion.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/cgmlp.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ctl_model/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ctl_model/asr_model_ctl.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ctl_model/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/e_branchformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/e_branchformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/e_branchformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/convolution.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/finetune/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/config.yaml create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/layers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/firered/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/firered/attention.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/firered/convert_FireRed_AED_L_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/firered/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/firered/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/firered/model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/firered/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/k2/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/k2/model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/cif.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/embedding.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/layers.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/paraformer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/search.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/sensevoice/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/sensevoice/sensevoice_small_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/conv2d.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/convolution.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/positionwise_feed_forward.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/bestrq/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/bestrq/bestrq_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/bestrq/mask.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/init_dataset.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/init_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/w2vbert/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/w2vbert/convert_w2vbert_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/w2vbert/w2vbert_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/wav2vec2/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/wav2vec2/quantizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/ssl/wav2vec2/wav2vec2_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transducer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transducer/joint.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transducer/predictor.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transducer/search/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transducer/search/greedy_search.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transducer/search/prefix_beam_search.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transducer/transducer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/asr_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/cmvn.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/convolution.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/ctc.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/decoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/decoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/embedding.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/label_smoothing_loss.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/norm.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/positionwise_feed_forward.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/search.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/transformer/swish.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/whisper/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/whisper/convert_whisper_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/models/whisper/whisper.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/base_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/bpe_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/char_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/hugging_face_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/paraformer_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/sentencepiece_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/tokenize_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/text/whisper_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/checkpoint.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/class_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/cmvn.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/common.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/config.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/context_graph.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/ctc_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/executor.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/file_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/fsdp_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/init_dataset.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/init_model.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/init_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/mask.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/rope_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/scheduler.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/wenet/wenet/utils/train_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/__main__.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/audio_convert.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/config_loader.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/scripts/audio_convert/readme.md create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/0_normalization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/1_denoise.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/2_anomaly_filter.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/3_fast_lang_id.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/4_split_and_tag.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/5_recognize_monitor.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/6_eval_wer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/7_eval_keyword_recall.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/anomaly_filter.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_keyword_recall.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/eval_wer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/merge_asr_by_source.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/normalization.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/recognize_monitor.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/pipeline/split_and_tag.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/audio_anomaly_filter.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/convert_audio.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/gtcrn_denoise.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/readme.txt create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/recognize.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/tools/split_audio.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/color_utils.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/compute_wer.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/fast_lang_id.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/generate_audio_list.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/gtcrn_denoise.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/recognize.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/run_wenet.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/src/utils/yaml_config_loader.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/audio_skip.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/metadata.yml create mode 100644 runtime/ops/mapper/audio_asr_pipeline/process.py create mode 100644 runtime/ops/mapper/audio_asr_pipeline/requirements.txt create mode 100644 runtime/ops/mapper/audio_asr_transcribe/README.md create mode 100644 runtime/ops/mapper/audio_asr_transcribe/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/README.md create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/cli/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/cli/hub.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/cli/model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/cli/punc_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/cli/transcribe.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/datapipes.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/dataset.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/deprecated/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/deprecated/dataset.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/deprecated/processor.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/kaldi_io.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/processor.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/dataset/wav_distortion.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/cgmlp.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/branchformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ctl_model/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ctl_model/asr_model_ctl.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ctl_model/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/e_branchformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/e_branchformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/e_branchformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/convolution.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/efficient_conformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/finetune/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/config.yaml create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/layers.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/finetune/lora/utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/firered/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/firered/attention.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/firered/convert_FireRed_AED_L_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/firered/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/firered/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/firered/model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/firered/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/k2/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/k2/model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/cif.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/embedding.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/layers.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/paraformer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/search.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/paraformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/sensevoice/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/sensevoice/convert_sensevoice_small_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/sensevoice/sensevoice_small_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/conv2d.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/convolution.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/positionwise_feed_forward.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/squeezeformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/bestrq/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/bestrq/bestrq_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/bestrq/mask.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/init_dataset.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/init_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/w2vbert/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/w2vbert/convert_w2vbert_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/w2vbert/w2vbert_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/wav2vec2/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/wav2vec2/quantizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/ssl/wav2vec2/wav2vec2_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transducer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transducer/joint.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transducer/predictor.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transducer/search/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transducer/search/greedy_search.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transducer/search/prefix_beam_search.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transducer/transducer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/asr_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/attention.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/cmvn.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/convolution.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/ctc.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/decoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/decoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/embedding.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/encoder.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/encoder_layer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/label_smoothing_loss.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/norm.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/positionwise_feed_forward.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/search.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/subsampling.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/transformer/swish.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/whisper/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/whisper/convert_whisper_to_wenet_config_and_ckpt.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/models/whisper/whisper.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/base_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/bpe_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/char_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/hugging_face_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/paraformer_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/sentencepiece_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/tokenize_utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/text/whisper_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/__init__.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/checkpoint.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/class_utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/cmvn.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/common.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/config.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/context_graph.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/ctc_utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/executor.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/file_utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/fsdp_utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/init_dataset.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/init_model.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/init_tokenizer.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/mask.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/rope_utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/scheduler.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/local_libs/wenet/wenet/utils/train_utils.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_preprocessor/src/utils/run_wenet.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/audio_skip.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/metadata.yml create mode 100644 runtime/ops/mapper/audio_asr_transcribe/process.py create mode 100644 runtime/ops/mapper/audio_asr_transcribe/requirements.txt create mode 100644 runtime/ops/mapper/audio_dc_offset_removal/README.md create mode 100644 runtime/ops/mapper/audio_dc_offset_removal/__init__.py create mode 100644 runtime/ops/mapper/audio_dc_offset_removal/audio_skip.py create mode 100644 runtime/ops/mapper/audio_dc_offset_removal/metadata.yml create mode 100644 runtime/ops/mapper/audio_dc_offset_removal/process.py create mode 100644 runtime/ops/mapper/audio_dc_offset_removal/requirements.txt create mode 100644 runtime/ops/mapper/audio_emotion_recognize/README.md create mode 100644 runtime/ops/mapper/audio_emotion_recognize/__init__.py create mode 100644 runtime/ops/mapper/audio_emotion_recognize/audio_skip.py create mode 100644 runtime/ops/mapper/audio_emotion_recognize/helpers/utils/emotion_small_model.py create mode 100644 runtime/ops/mapper/audio_emotion_recognize/metadata.yml create mode 100644 runtime/ops/mapper/audio_emotion_recognize/process.py create mode 100644 runtime/ops/mapper/audio_emotion_recognize/requirements.txt create mode 100644 runtime/ops/mapper/audio_fast_lang_id/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/audio_skip.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/helpers/utils/color_utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/helpers/utils/compute_wer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/helpers/utils/fast_lang_id.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/helpers/utils/generate_audio_list.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/helpers/utils/gtcrn_denoise.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/helpers/utils/yaml_config_loader.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/alignment/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/alignment/aligner.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/alignment/ctc_segmentation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/augment/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/augment/augmenter.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/augment/codec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/augment/freq_domain.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/augment/preparation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/augment/time_domain.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/core.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/audio_io.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/batch.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/dataio.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/dataloader.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/dataset.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/encoder.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/iterators.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/legacy.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/preprocess.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/sampler.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/dataio/wer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/decoders/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/decoders/ctc.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/decoders/language_model.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/decoders/scorer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/decoders/seq2seq.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/decoders/transducer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/decoders/utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/ASR.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/SLU.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/ST.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/TTS.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/VAD.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/classifiers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/diarization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/encoders.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/enhancement.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/interfaces.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/interpretability.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/metrics.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/separation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/speaker.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/text.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/inference/vocoders.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/alignment/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/alignment/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/alignment/ctc_seg.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/alignment/diarization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/discrete_ssl.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/kmeans.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/speechtokenizer_interface.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/wavtokenizer_interface.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/decoders/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/decoders/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/decoders/kenlm_scorer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/hdf5/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/hdf5/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/hdf5/cached_item.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/encodec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/gpt.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/hubert.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/huggingface.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/labse.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/llama.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/mbart.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/mert.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/mimi.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/nllb.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/textencoder.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/vocos.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/w2v_bert.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/wav2vec2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/wavlm.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/weighted_ssl.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/whisper.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/transformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/util.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/align.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/graph_compiler.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/lattice_decoder.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/lexicon.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/losses.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/prepare_lang.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/k2_fsa/utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/models/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/models/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/models/sgmse_plus.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/nlp/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/nlp/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/nlp/bgeM3_embeddings.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/nlp/bleu.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/nlp/flair_embeddings.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/nlp/flair_tagger.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/nlp/spacy_pipeline.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/numba/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/numba/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/numba/transducer_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/tests/test_cached_item.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/tests/test_ctc_segmentation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/tests/test_k2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/integrations/tests/test_nlp.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lm/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lm/arpa.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lm/counting.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lm/ngram.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/beamform_multimic.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/downsampling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/features.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/BESTRQ.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/CRDNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/Cnn14.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/ContextNet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/DiffWave.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/ECAPA_TDNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/ESPnetVGG.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/EnhanceResnet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/FastSpeech2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/GatedNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/HifiGAN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/L2I.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/MSTacotron2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/MetricGAN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/MetricGAN_U.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/PIQ.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/RNNLM.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/ResNet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/Tacotron2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/VanillaNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/Xvector.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/beats.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/bsq.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/conv_tasnet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/convolution.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/discrete/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/discrete/dac.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/dual_path.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/fairseq_wav2vec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/g2p/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/g2p/dataio.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/g2p/homograph.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/g2p/model.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/kmeans.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/resepformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/segan_model.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/Branchformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/Conformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/Transformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerASR.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerLM.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerSE.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerST.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/transformer/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/lobes/models/wav2vec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/log-config.yaml create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/CNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/RNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/activations.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/adapters.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/attention.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/autoencoders.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/complex_networks/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_CNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_RNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_linear.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_normalization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_ops.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/containers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/diffusion.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/dropout.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/embedding.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/hypermixing.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/linear.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/loss/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/loss/guidedattn_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/loss/si_snr_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/loss/stoi_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/losses.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/normalization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/pooling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quantisers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_CNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_RNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_linear.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_normalization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_ops.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_pooling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/schedulers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/transducer/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/transducer/transducer_joint.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/unet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/nnet/utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/NMF.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/PLDA_LDA.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/decomposition.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/diarization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/features.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/multi_mic.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/signal_processing.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/processing/vocal_features.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/tokenizers/SentencePiece.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/tokenizers/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/tokenizers/discrete_SSL_tokenizer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/Accuracy.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/DER.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/EDER.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/_workarounds.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/autocast.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/bertscore.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/bleu.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/callchains.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/checkpoints.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/data_pipeline.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/data_utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/depgraph.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/dictionaries.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/distances.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/distributed.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/dynamic_chunk_training.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/edit_distance.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/epoch_loop.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/fetching.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/filter_analysis.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/hparams.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/hpopt.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/importutils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/kmeans.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/logger.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/metric_stats.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/optimizers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/parallel.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/parameter_transfer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/pretrained.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/profiling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/quirks.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/repro.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/run_opts.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/seed.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/semdist.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/streaming.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/superpowers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/text_to_sequence.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/torch_audio_backend.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/utils/train_logger.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/local_libs/speechbrain/speechbrain/version.txt create mode 100644 runtime/ops/mapper/audio_fast_lang_id/metadata.yml create mode 100644 runtime/ops/mapper/audio_fast_lang_id/process.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id/requirements.txt create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/audio_skip.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/helpers/utils/color_utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/helpers/utils/compute_wer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/helpers/utils/fast_lang_id.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/helpers/utils/generate_audio_list.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/helpers/utils/gtcrn_denoise.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/helpers/utils/yaml_config_loader.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/alignment/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/alignment/aligner.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/alignment/ctc_segmentation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/augment/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/augment/augmenter.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/augment/codec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/augment/freq_domain.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/augment/preparation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/augment/time_domain.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/core.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/audio_io.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/batch.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/dataio.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/dataloader.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/dataset.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/encoder.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/iterators.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/legacy.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/preprocess.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/sampler.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/dataio/wer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/decoders/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/decoders/ctc.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/decoders/language_model.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/decoders/scorer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/decoders/seq2seq.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/decoders/transducer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/decoders/utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/ASR.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/SLU.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/ST.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/TTS.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/VAD.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/classifiers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/diarization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/encoders.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/enhancement.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/interfaces.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/interpretability.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/metrics.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/separation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/speaker.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/text.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/inference/vocoders.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/alignment/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/alignment/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/alignment/ctc_seg.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/alignment/diarization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/discrete_ssl.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/kmeans.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/speechtokenizer_interface.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/audio_tokenizers/wavtokenizer_interface.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/decoders/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/decoders/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/decoders/kenlm_scorer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/hdf5/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/hdf5/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/hdf5/cached_item.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/encodec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/gpt.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/hubert.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/huggingface.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/labse.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/llama.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/mbart.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/mert.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/mimi.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/nllb.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/textencoder.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/vocos.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/w2v_bert.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/wav2vec2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/wavlm.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/weighted_ssl.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/whisper.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/transformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/huggingface/wordemb/util.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/align.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/graph_compiler.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/lattice_decoder.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/lexicon.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/losses.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/prepare_lang.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/k2_fsa/utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/models/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/models/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/models/sgmse_plus.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/nlp/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/nlp/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/nlp/bgeM3_embeddings.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/nlp/bleu.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/nlp/flair_embeddings.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/nlp/flair_tagger.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/nlp/spacy_pipeline.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/numba/README.md create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/numba/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/numba/transducer_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/tests/test_cached_item.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/tests/test_ctc_segmentation.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/tests/test_k2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/integrations/tests/test_nlp.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lm/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lm/arpa.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lm/counting.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lm/ngram.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/beamform_multimic.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/downsampling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/features.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/BESTRQ.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/CRDNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/Cnn14.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/ContextNet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/DiffWave.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/ECAPA_TDNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/ESPnetVGG.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/EnhanceResnet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/FastSpeech2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/GatedNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/HifiGAN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/L2I.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/MSTacotron2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/MetricGAN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/MetricGAN_U.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/PIQ.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/RNNLM.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/ResNet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/Tacotron2.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/VanillaNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/Xvector.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/beats.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/bsq.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/conv_tasnet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/convolution.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/discrete/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/discrete/dac.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/dual_path.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/fairseq_wav2vec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/g2p/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/g2p/dataio.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/g2p/homograph.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/g2p/model.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/kmeans.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/resepformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/segan_model.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/Branchformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/Conformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/Transformer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerASR.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerLM.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerSE.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/TransformerST.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/transformer/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/lobes/models/wav2vec.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/log-config.yaml create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/CNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/RNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/activations.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/adapters.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/attention.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/autoencoders.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/complex_networks/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_CNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_RNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_linear.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_normalization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/complex_networks/c_ops.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/containers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/diffusion.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/dropout.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/embedding.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/hypermixing.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/linear.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/loss/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/loss/guidedattn_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/loss/si_snr_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/loss/stoi_loss.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/losses.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/normalization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/pooling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quantisers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_CNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_RNN.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_linear.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_normalization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_ops.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/quaternion_networks/q_pooling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/schedulers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/transducer/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/transducer/transducer_joint.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/unet.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/nnet/utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/NMF.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/PLDA_LDA.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/decomposition.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/diarization.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/features.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/multi_mic.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/signal_processing.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/processing/vocal_features.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/tokenizers/SentencePiece.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/tokenizers/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/tokenizers/discrete_SSL_tokenizer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/Accuracy.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/DER.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/EDER.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/__init__.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/_workarounds.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/autocast.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/bertscore.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/bleu.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/callchains.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/checkpoints.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/data_pipeline.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/data_utils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/depgraph.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/dictionaries.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/distances.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/distributed.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/dynamic_chunk_training.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/edit_distance.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/epoch_loop.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/fetching.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/filter_analysis.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/hparams.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/hpopt.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/importutils.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/kmeans.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/logger.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/metric_stats.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/optimizers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/parallel.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/parameter_transfer.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/pretrained.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/profiling.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/quirks.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/repro.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/run_opts.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/seed.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/semdist.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/streaming.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/superpowers.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/text_to_sequence.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/torch_audio_backend.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/utils/train_logger.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/local_libs/speechbrain/speechbrain/version.txt create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/metadata.yml create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/process.py create mode 100644 runtime/ops/mapper/audio_fast_lang_id_text/requirements.txt create mode 100644 runtime/ops/mapper/audio_format_convert/README.md create mode 100644 runtime/ops/mapper/audio_format_convert/__init__.py create mode 100644 runtime/ops/mapper/audio_format_convert/audio_skip.py create mode 100644 runtime/ops/mapper/audio_format_convert/metadata.yml create mode 100644 runtime/ops/mapper/audio_format_convert/process.py create mode 100644 runtime/ops/mapper/audio_format_convert/requirements.txt create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/README.md create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/__init__.py create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/audio_skip.py create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/helpers/utils/color_utils.py create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/helpers/utils/gtcrn_denoise.py create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/metadata.yml create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/process.py create mode 100644 runtime/ops/mapper/audio_gtcrn_denoise/requirements.txt create mode 100644 runtime/ops/mapper/audio_hum_notch/README.md create mode 100644 runtime/ops/mapper/audio_hum_notch/__init__.py create mode 100644 runtime/ops/mapper/audio_hum_notch/audio_skip.py create mode 100644 runtime/ops/mapper/audio_hum_notch/metadata.yml create mode 100644 runtime/ops/mapper/audio_hum_notch/process.py create mode 100644 runtime/ops/mapper/audio_hum_notch/requirements.txt create mode 100644 runtime/ops/mapper/audio_noise_gate/README.md create mode 100644 runtime/ops/mapper/audio_noise_gate/__init__.py create mode 100644 runtime/ops/mapper/audio_noise_gate/audio_skip.py create mode 100644 runtime/ops/mapper/audio_noise_gate/metadata.yml create mode 100644 runtime/ops/mapper/audio_noise_gate/process.py create mode 100644 runtime/ops/mapper/audio_noise_gate/requirements.txt create mode 100644 runtime/ops/mapper/audio_pre_emphasis/README.md create mode 100644 runtime/ops/mapper/audio_pre_emphasis/__init__.py create mode 100644 runtime/ops/mapper/audio_pre_emphasis/audio_skip.py create mode 100644 runtime/ops/mapper/audio_pre_emphasis/metadata.yml create mode 100644 runtime/ops/mapper/audio_pre_emphasis/process.py create mode 100644 runtime/ops/mapper/audio_pre_emphasis/requirements.txt create mode 100644 runtime/ops/mapper/audio_quantize_encode/README.md create mode 100644 runtime/ops/mapper/audio_quantize_encode/__init__.py create mode 100644 runtime/ops/mapper/audio_quantize_encode/audio_skip.py create mode 100644 runtime/ops/mapper/audio_quantize_encode/metadata.yml create mode 100644 runtime/ops/mapper/audio_quantize_encode/process.py create mode 100644 runtime/ops/mapper/audio_quantize_encode/requirements.txt create mode 100644 runtime/ops/mapper/audio_rms_loudness_normalize/README.md create mode 100644 runtime/ops/mapper/audio_rms_loudness_normalize/__init__.py create mode 100644 runtime/ops/mapper/audio_rms_loudness_normalize/audio_skip.py create mode 100644 runtime/ops/mapper/audio_rms_loudness_normalize/metadata.yml create mode 100644 runtime/ops/mapper/audio_rms_loudness_normalize/process.py create mode 100644 runtime/ops/mapper/audio_rms_loudness_normalize/requirements.txt create mode 100644 runtime/ops/mapper/audio_simple_agc/README.md create mode 100644 runtime/ops/mapper/audio_simple_agc/__init__.py create mode 100644 runtime/ops/mapper/audio_simple_agc/audio_skip.py create mode 100644 runtime/ops/mapper/audio_simple_agc/metadata.yml create mode 100644 runtime/ops/mapper/audio_simple_agc/process.py create mode 100644 runtime/ops/mapper/audio_simple_agc/requirements.txt create mode 100644 runtime/ops/mapper/audio_soft_peak_limiter/README.md create mode 100644 runtime/ops/mapper/audio_soft_peak_limiter/__init__.py create mode 100644 runtime/ops/mapper/audio_soft_peak_limiter/audio_skip.py create mode 100644 runtime/ops/mapper/audio_soft_peak_limiter/metadata.yml create mode 100644 runtime/ops/mapper/audio_soft_peak_limiter/process.py create mode 100644 runtime/ops/mapper/audio_soft_peak_limiter/requirements.txt create mode 100644 runtime/ops/mapper/audio_sound_classify/README.md create mode 100644 runtime/ops/mapper/audio_sound_classify/__init__.py create mode 100644 runtime/ops/mapper/audio_sound_classify/audio_skip.py create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/ast_vendor/__init__.py create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/ast_vendor/ast_models.py create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/audioset_tagging_cnn/metadata/class_labels_indices.csv create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/panns_inference/LICENSE.MIT create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/panns_inference/panns_inference/__init__.py create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/panns_inference/panns_inference/config.py create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/panns_inference/panns_inference/inference.py create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/panns_inference/panns_inference/models.py create mode 100644 runtime/ops/mapper/audio_sound_classify/local_libs/panns_inference/panns_inference/pytorch_utils.py create mode 100644 runtime/ops/mapper/audio_sound_classify/metadata.yml create mode 100644 runtime/ops/mapper/audio_sound_classify/models/panns/classes_macro_draft.tsv create mode 100644 runtime/ops/mapper/audio_sound_classify/models/recog/audioset_macro_map_v1.json create mode 100644 runtime/ops/mapper/audio_sound_classify/process.py create mode 100644 runtime/ops/mapper/audio_sound_classify/requirements.txt create mode 100644 runtime/ops/mapper/audio_telephony_bandpass/README.md create mode 100644 runtime/ops/mapper/audio_telephony_bandpass/__init__.py create mode 100644 runtime/ops/mapper/audio_telephony_bandpass/audio_skip.py create mode 100644 runtime/ops/mapper/audio_telephony_bandpass/metadata.yml create mode 100644 runtime/ops/mapper/audio_telephony_bandpass/process.py create mode 100644 runtime/ops/mapper/audio_telephony_bandpass/requirements.txt create mode 100644 runtime/ops/mapper/audio_text_summarize/README.md create mode 100644 runtime/ops/mapper/audio_text_summarize/__init__.py create mode 100644 runtime/ops/mapper/audio_text_summarize/metadata.yml create mode 100644 runtime/ops/mapper/audio_text_summarize/process.py create mode 100644 runtime/ops/mapper/audio_text_summarize/requirements.txt create mode 100644 runtime/ops/mapper/audio_trim_silence_edges/README.md create mode 100644 runtime/ops/mapper/audio_trim_silence_edges/__init__.py create mode 100644 runtime/ops/mapper/audio_trim_silence_edges/audio_skip.py create mode 100644 runtime/ops/mapper/audio_trim_silence_edges/metadata.yml create mode 100644 runtime/ops/mapper/audio_trim_silence_edges/process.py create mode 100644 runtime/ops/mapper/audio_trim_silence_edges/requirements.txt diff --git a/runtime/ops/mapper/__init__.py b/runtime/ops/mapper/__init__.py index ed0a0fcb2..193db4913 100644 --- a/runtime/ops/mapper/__init__.py +++ b/runtime/ops/mapper/__init__.py @@ -47,7 +47,30 @@ def _import_operators(): from . import remove_duplicate_sentences from . import knowledge_relation_slice from . import pii_ner_detection - # ===== Video operators (PR1-PR5) ===== + + # ===== Audio operators ===== + from . import audio_anomaly_filter + from . import audio_asr_pipeline + from . import audio_asr_transcribe + from . import audio_dc_offset_removal + from . import audio_emotion_recognize + from . import audio_fast_lang_id + from . import audio_fast_lang_id_text + from . import audio_format_convert + from . import audio_gtcrn_denoise + from . import audio_hum_notch + from . import audio_noise_gate + from . import audio_pre_emphasis + from . import audio_quantize_encode + from . import audio_rms_loudness_normalize + from . import audio_simple_agc + from . import audio_soft_peak_limiter + from . import audio_sound_classify + from . import audio_telephony_bandpass + from . import audio_text_summarize + from . import audio_trim_silence_edges + + # ===== Video operators (PR1-PR5) ===== from . import _video_common from . import video_format_convert from . import video_sensitive_detect diff --git a/runtime/ops/mapper/audio_anomaly_filter/README.md b/runtime/ops/mapper/audio_anomaly_filter/README.md new file mode 100644 index 000000000..fab93a764 --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/README.md @@ -0,0 +1,41 @@ +# AudioAnomalyFilter 异常语音检测与过滤算子 + +## 概述 + +AudioAnomalyFilter 用于对音频做快速质量检测,计算时长、静音帧比例与音频可读性,并给出 `quality_flag`。算子不再通过清空 `text/data` 模拟删除文件,而是写入结构化质量标签;下游音频算子可根据标签软跳过异常样本。 + +## 功能特性 + +- **时长检测**:支持最小时长/最大时长阈值 +- **静音比例检测**:基于短时 RMS 统计静音帧占比 +- **可读性检测**:文本文件强行改成 `.wav` 等不可读取音频会被标记为 `invalid` +- **下游门控**:支持让后续音频算子跳过异常样本,符合 DataMate 一文件一输出链路 +- **结果结构化输出**:报告写入 `ext_params.audio_quality` + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| minDur | inputNumber | 1.0 | 最小时长(秒),小于该值视为异常 | +| maxDur | inputNumber | 20000.0 | 最大时长(秒),大于该值视为异常 | +| silenceRatioTh | slider | 0.8 | 静音帧比例阈值(0~1),>= 阈值视为异常 | +| silenceRmsRatioTh | slider | 0.05 | 静音判定阈值 = global_rms * 该比例 | +| skipInvalidDownstream | switch | true | true=后续音频算子遇到 invalid 软跳过;false=仅打标并继续处理 | + +## 输入输出 + +- **输入**:`sample["filePath"]`(音频文件路径) +- **输出**: + - `sample["ext_params"]["audio_quality"]`: + - `quality_flag`: `ok/invalid` + - `duration/silence_ratio/global_rms/reason/read_error/skip_downstream` + - 如果该算子为链路最后一个算子:导出当前音频,质量报告写入 `ext_params.audio_quality` + - 如果该算子位于链路中间:保持当前音频,后续音频算子按 `skip_downstream` 决定是否软跳过 + +## 依赖说明 + +- **Python 依赖**:优先 `torchaudio`,兜底 `soundfile` + +## 版本历史 + +- **v1.0.0**:支持时长/静音比例/可读性检测,按 DataMate 链路语义写质量标签并门控下游 diff --git a/runtime/ops/mapper/audio_anomaly_filter/__init__.py b/runtime/ops/mapper/audio_anomaly_filter/__init__.py new file mode 100644 index 000000000..fb9b45218 --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioAnomalyFilter', + module_path="ops.mapper.audio_anomaly_filter.process") diff --git a/runtime/ops/mapper/audio_anomaly_filter/audio_skip.py b/runtime/ops/mapper/audio_anomaly_filter/audio_skip.py new file mode 100644 index 000000000..aec496137 --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/audio_skip.py @@ -0,0 +1,114 @@ +# -- encoding: utf-8 -- + +from pathlib import Path +from typing import Any, Dict + +from loguru import logger + + +AUDIO_EXTS = { + "aac", + "aif", + "aiff", + "amr", + "au", + "flac", + "m4a", + "mp3", + "oga", + "ogg", + "opus", + "snd", + "wav", + "webm", + "wma", +} + + +def _parts(path_value: str) -> set[str]: + try: + return {part.lower() for part in Path(path_value).parts} + except Exception: + return set() + + +def is_reference_sample(sample: Dict[str, Any], filepath_key: str = "filePath") -> bool: + path_value = str(sample.get(filepath_key) or "") + return "references" in _parts(path_value) + + +def _ext_from_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", +) -> str: + for key in (target_type_key, filetype_key): + value = str(sample.get(key) or "").strip().lower().lstrip(".") + if value: + return value + path_value = str(sample.get(filepath_key) or "").strip() + return Path(path_value).suffix.lower().lstrip(".") if path_value else "" + + +def is_audio_sample( + sample: Dict[str, Any], + filepath_key: str = "filePath", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + data_key: str = "data", +) -> bool: + if is_reference_sample(sample, filepath_key): + return False + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return True + return _ext_from_sample(sample, filepath_key, filetype_key, target_type_key) in AUDIO_EXTS + + +def invalid_quality_reason(sample: Dict[str, Any], ext_params_key: str = "ext_params") -> str: + for key in ("fileName", "sourceFileName", "filePath"): + marker_source = Path(str(sample.get(key) or "")).stem.lower() + marker = "__quality_invalid" + if marker in marker_source: + reason = marker_source.split(marker, 1)[1].strip("_") or "invalid_audio" + return f"invalid_audio_quality:{reason}" + + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + return "" + quality = ext.get("audio_quality", {}) + if not isinstance(quality, dict): + return "" + if str(quality.get("quality_flag") or "").strip().lower() != "invalid": + return "" + skip_downstream = quality.get("skip_downstream", True) + if isinstance(skip_downstream, str): + skip_downstream = skip_downstream.strip().lower() in {"1", "true", "yes", "y", "on"} + if not skip_downstream: + return "" + reason = str(quality.get("reason") or "invalid_audio").strip() + return f"invalid_audio_quality:{reason}" + + +def mark_skipped_sample( + sample: Dict[str, Any], + reason: str, + op_name: str, + text_key: str = "text", + data_key: str = "data", + filetype_key: str = "fileType", + target_type_key: str = "target_type", + ext_params_key: str = "ext_params", +) -> Dict[str, Any]: + ext = sample.get(ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext.setdefault("audio_skip", {})[op_name] = reason + sample[ext_params_key] = ext + sample[text_key] = "" + sample[data_key] = b"" + sample[filetype_key] = "" + sample[target_type_key] = "" + logger.info(f"fileName: {sample.get('fileName')}, method: {op_name} skipped: {reason}") + return sample diff --git a/runtime/ops/mapper/audio_anomaly_filter/metadata.yml b/runtime/ops/mapper/audio_anomaly_filter/metadata.yml new file mode 100644 index 000000000..7f0d9394a --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/metadata.yml @@ -0,0 +1,66 @@ +name: 'audioOps-异常语音检测与过滤' +name_en: 'audioOps-Audio Anomaly Detect & Filter' +description: '对音频做快速异常检测:时长范围、静音帧比例与可读性。结果写入 ext_params.audio_quality;可控制下游音频算子是否跳过异常样本。' +description_en: 'Fast audio anomaly detection (duration, silence ratio and readability). Writes ext_params.audio_quality and can make downstream audio ops skip invalid samples.' +language: 'python' +vendor: 'huawei' +raw_id: 'AudioAnomalyFilter' +version: '1.0.0' +types: + - 'cleaning' +modal: 'audio' +inputs: 'audio' +outputs: 'audio' +settings: + minDur: + name: '最小时长(秒)' + type: 'inputNumber' + description: '小于该值视为异常。' + defaultVal: 1.0 + min: 0 + max: 36000 + step: 0.1 + maxDur: + name: '最大时长(秒)' + type: 'inputNumber' + description: '大于该值视为异常。' + defaultVal: 20000.0 + min: 0 + max: 360000 + step: 1 + silenceRatioTh: + name: '静音帧比例阈值' + type: 'slider' + description: '静音帧比例 >= 阈值 时视为异常。' + defaultVal: 0.8 + min: 0 + max: 1 + step: 0.01 + silenceRmsRatioTh: + name: '静音判定比例' + type: 'slider' + description: '静音判定阈值 = global_rms * 该比例。' + defaultVal: 0.05 + min: 0 + max: 1 + step: 0.01 + skipInvalidDownstream: + name: '下游跳过异常音频' + description: '开启后,后续音频算子遇到 quality_flag=invalid 会软跳过;关闭后仅打标并继续处理。不可读取的伪 wav 会被标为 invalid。' + type: 'switch' + defaultVal: 'true' + required: false + checkedLabel: '跳过' + unCheckedLabel: '继续' +runtime: + memory: 104857600 + cpu: 0.2 + gpu: 0 + npu: 0 + storage: 10MB + +metrics: + - name: '处理耗时' + metric: '依输入音频长度与运行环境而定' +release: + - '首次发布' diff --git a/runtime/ops/mapper/audio_anomaly_filter/process.py b/runtime/ops/mapper/audio_anomaly_filter/process.py new file mode 100644 index 000000000..5d9cb278c --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/process.py @@ -0,0 +1,221 @@ +# -- encoding: utf-8 -- + +import math +import re +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from loguru import logger + +from datamate.core.base_op import Mapper + +try: + from .audio_skip import is_audio_sample, mark_skipped_sample +except ImportError: + from audio_skip import is_audio_sample, mark_skipped_sample + + +def _as_bool(v: object) -> bool: + if isinstance(v, bool): + return v + return str(v).strip().lower() in {"1", "true", "yes", "y", "on"} + + +def _audio_ext(sample: Dict[str, Any], default_ext: str = "wav") -> str: + for key in ("target_type", "fileType"): + ext = str(sample.get(key) or "").strip().lower().lstrip(".") + if ext: + return ext + path_value = str(sample.get("filePath") or "").strip() + suffix = Path(path_value).suffix.lower().lstrip(".") if path_value else "" + return suffix or default_ext + + +def _source_audio_bytes(sample: Dict[str, Any], data_key: str, filepath_key: str, read_file: bool = False) -> bytes: + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + return bytes(data) + if not read_file: + return b"" + path = Path(str(sample.get(filepath_key) or "")).expanduser() + if path.exists() and path.is_file(): + return path.read_bytes() + return b"" + + +def _safe_marker(value: str, default: str = "invalid_audio") -> str: + marker = re.sub(r"[^A-Za-z0-9._-]+", "_", str(value or default)).strip("._-") + return marker[:80] or default + + +def _strip_quality_marker(stem: str) -> str: + return re.sub(r"__quality_invalid(?:_[A-Za-z0-9._-]+)?$", "", str(stem or "sample")) + + +def _mark_quality_filename(sample: Dict[str, Any], filename_key: str, reason: str, target_ext: str) -> None: + file_name = str(sample.get(filename_key) or "").strip() + stem = _strip_quality_marker(Path(file_name).stem if file_name else "sample") + sample[filename_key] = f"{stem}__quality_invalid_{_safe_marker(reason)}.{target_ext}" + + +def _load_wave_mono(path: Path) -> Tuple[List[float], int]: + try: + import torchaudio # type: ignore + + wav, sr = torchaudio.load(str(path)) + if wav.ndim > 1: + wav = wav.mean(dim=0, keepdim=True) + return wav.squeeze(0).float().tolist(), int(sr) + except Exception: + try: + import soundfile as sf # type: ignore + + data, sr = sf.read(str(path), always_2d=False) + if getattr(data, "ndim", 1) > 1: + data = data.mean(axis=1) + return data.tolist(), int(sr) + except Exception as e: + raise RuntimeError(f"failed to read audio: {path}, error={e}") from e + + +def _load_source_mono(sample: Dict[str, Any], data_key: str, filepath_key: str) -> Tuple[List[float], int]: + data = sample.get(data_key) + if isinstance(data, (bytes, bytearray)) and data: + with tempfile.NamedTemporaryFile(suffix=f".{_audio_ext(sample)}", delete=False) as tmp: + tmp.write(bytes(data)) + tmp_path = Path(tmp.name) + try: + return _load_wave_mono(tmp_path) + finally: + try: + tmp_path.unlink() + except Exception: + pass + return _load_wave_mono(Path(str(sample.get(filepath_key) or "")).expanduser().resolve()) + + +def _frame_rms(x: List[float], sr: int, frame_ms: float, hop_ms: float) -> Tuple[List[float], float]: + if not x or sr <= 0: + return [], 0.0 + frame_len = max(1, int(sr * frame_ms / 1000.0)) + hop = max(1, int(sr * hop_ms / 1000.0)) + total_sq = sum(float(v) * float(v) for v in x) + global_rms = math.sqrt(total_sq / max(1, len(x))) + rms_list: List[float] = [] + for start in range(0, len(x), hop): + end = min(start + frame_len, len(x)) + if end <= start: + continue + frame = x[start:end] + rms_list.append(math.sqrt(sum(float(v) * float(v) for v in frame) / max(1, len(frame)))) + return rms_list, global_rms + + +class AudioAnomalyFilter(Mapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.min_dur = float(kwargs.get("minDur", 1.0)) + self.max_dur = float(kwargs.get("maxDur", 20000.0)) + self.silence_ratio_th = float(kwargs.get("silenceRatioTh", 0.8)) + self.silence_rms_ratio_th = float(kwargs.get("silenceRmsRatioTh", 0.05)) + self.skip_invalid_downstream = _as_bool(kwargs.get("skipInvalidDownstream", True)) + + def execute(self, sample: Dict[str, Any]) -> Dict[str, Any]: + start = time.time() + if not is_audio_sample(sample, self.filepath_key, self.filetype_key, self.target_type_key, self.data_key): + return mark_skipped_sample( + sample, + "non_audio_or_reference_file", + self.__class__.__name__, + self.text_key, + self.data_key, + self.filetype_key, + self.target_type_key, + self.ext_params_key, + ) + + audio_bytes_for_export = _source_audio_bytes(sample, self.data_key, self.filepath_key) + path_value = str(sample.get(self.filepath_key) or "").strip() + path_exists = bool(audio_bytes_for_export) or (bool(path_value) and Path(path_value).expanduser().exists()) + reasons: List[str] = [] + quality_flag = "ok" + read_error = "" + + if not path_exists: + duration = 0.0 + silence_ratio = 1.0 + global_rms = 0.0 + quality_flag = "invalid" + read_error = f"FileNotFoundError: input audio does not exist: {sample.get(self.filepath_key)}" + reasons.append("missing_audio_file") + else: + try: + wav, sr = _load_source_mono(sample, self.data_key, self.filepath_key) + duration = float(len(wav)) / float(sr) if sr > 0 else 0.0 + rms_frames, global_rms = _frame_rms(wav, sr, frame_ms=25.0, hop_ms=10.0) + if not rms_frames or global_rms <= 0.0: + silence_ratio = 1.0 + else: + threshold = max(1e-8, global_rms * float(self.silence_rms_ratio_th)) + silent = sum(1 for rms in rms_frames if rms < threshold) + silence_ratio = float(silent) / float(len(rms_frames)) + except Exception as e: + duration = 0.0 + silence_ratio = 1.0 + global_rms = 0.0 + quality_flag = "invalid" + read_error = f"{type(e).__name__}: {e}" + reasons.append("unreadable_audio") + + if duration <= 0.0: + quality_flag = "invalid" + if "duration_le_zero" not in reasons: + reasons.append("duration_le_zero") + elif duration < self.min_dur: + quality_flag = "invalid" + reasons.append("too_short") + elif duration > self.max_dur: + quality_flag = "invalid" + reasons.append("too_long") + if silence_ratio >= self.silence_ratio_th: + quality_flag = "invalid" + reasons.append("too_much_silence") + + report = { + "quality_flag": quality_flag, + "duration": round(duration, 3), + "silence_ratio": round(silence_ratio, 4), + "global_rms": round(global_rms, 6), + "reason": ",".join(reasons) if reasons else "", + "read_error": read_error, + "skip_downstream": self.skip_invalid_downstream, + } + ext = sample.get(self.ext_params_key, {}) + if not isinstance(ext, dict): + ext = {"_raw": ext} + ext["audio_quality"] = report + sample[self.ext_params_key] = ext + + sample[self.text_key] = "" + if self.is_last_op and not audio_bytes_for_export: + audio_bytes_for_export = _source_audio_bytes( + sample, + self.data_key, + self.filepath_key, + read_file=True, + ) + if audio_bytes_for_export: + sample[self.data_key] = audio_bytes_for_export + if self.is_last_op: + target_ext = _audio_ext(sample) + sample[self.filetype_key] = "txt" + sample[self.target_type_key] = target_ext + if quality_flag == "invalid": + _mark_quality_filename(sample, self.filename_key, report["reason"] or "invalid_audio", target_ext) + + logger.info( + f"fileName: {sample.get(self.filename_key)}, method: AudioAnomalyFilter costs {time.time() - start:6f} s" + ) + return sample diff --git a/runtime/ops/mapper/audio_anomaly_filter/requirements.txt b/runtime/ops/mapper/audio_anomaly_filter/requirements.txt new file mode 100644 index 000000000..fd0cf60b5 --- /dev/null +++ b/runtime/ops/mapper/audio_anomaly_filter/requirements.txt @@ -0,0 +1,2 @@ +torchaudio +soundfile diff --git a/runtime/ops/mapper/audio_asr_pipeline/README.md b/runtime/ops/mapper/audio_asr_pipeline/README.md new file mode 100644 index 000000000..84823bde0 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/README.md @@ -0,0 +1,62 @@ +# AudioAsrPipeline 音频预处理与中英ASR流水线算子 + +## 概述 + +AudioAsrPipeline 将 `audio_preprocessor` 的推荐流水线封装为一个 DataMate Mapper 算子:标准化、(可选)降噪、(可选)异常过滤、语言识别、切分、ASR 识别与合并,并可选计算中英文关键词召回率。算子按 DataMate 单样本范式处理当前输入音频,最终只导出该输入文件对应的一个 `.txt` 转写文件,并在 `ext_params` 中记录中间产物路径,便于排查与验收。 + +## 功能特性 + +- **端到端流水线**:normalization →(可选)GTCRN →(可选)异常过滤 → LID → split → ASR → merge →(可选)关键词召回率 +- **可配置**:每个关键步骤参数化(降噪开关、过滤阈值、LID 截断秒数、切分长度、ASR 设备等) +- **结果可追溯**:中间产物路径记录在 `ext_params.audio_asr.artifacts` +- **关键词召回率**:复用 `audio_preprocessor/src/pipeline/eval_keyword_recall.py`,生成 `keyword_recall.txt` 并写入导出目录 +- **一入一出**:每个输入音频输出一个 `.txt`,内容为该音频的转写文本 + +## 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|---|---|---:|---| +| doDenoise | switch | false | 是否启用 GTCRN 降噪 | +| denoiseModelPath | input | /models/AudioOperations/gtcrn/gtcrn.onnx | GTCRN ONNX 模型绝对路径 | +| doAnomalyFilter | switch | true | 是否启用异常语音检测与过滤 | +| minDur | inputNumber | 1.0 | 最小时长(秒) | +| maxDur | inputNumber | 20000.0 | 最大时长(秒) | +| silenceRatioTh | slider | 0.8 | 静音帧比例阈值(0~1) | +| silenceRmsRatioTh | slider | 0.05 | 静音判定阈值比例 | +| lidModelSource | input | /models/AudioOperations/lid/speechbrain_lang-id-voxlingua107-ecapa | SpeechBrain LID 本地模型目录 | +| lidDevice | select | cpu | LID 推理设备(cpu/cuda/npu) | +| lidMaxSeconds | inputNumber | 3.0 | LID 只取前 N 秒,0=全长 | +| maxSegmentSeconds | inputNumber | 120 | 切分最大秒数 | +| asrDevice | select | npu | ASR 设备参数(npu/cpu/auto) | +| doKeywordRecall | switch | false | 是否在 ASR 后计算关键词召回率 | +| referencePath | input | /dataset/{dataset_id}/references | 参考文件或参考目录路径;写入 `extraFilePath` 供后续评估算子读取,路径不存在会回退 | +| zhKeywordPath | input | /dataset/{dataset_id}/references/zh_keyword.txt | 中文关键词文件;不存在时优先从 `referencePath/extraFilePath` 找 `zh_keyword.txt` | +| enKeywordPath | input | /dataset/{dataset_id}/references/en_keyword.txt | 英文关键词文件;不存在时优先从 `referencePath/extraFilePath` 找 `en_keyword.txt` | +| keepKeywordDetails | switch | false | 是否将逐句 hit/miss 明细写入 `ext_params` | + +## 输入输出 + +- **输入**:`sample["filePath"]`(音频文件路径) +- **输出**: + - `sample["text"]`:当前输入音频对应的转写文本,并导出为 `.txt` + - `sample["ext_params"]["audio_asr"]`: + - `lang`:LID 结果(zh/en) + - `artifacts`:中间产物路径(normalized/denoise/lid/split/asr/merged_text) + - `reference`:填写 `referencePath` 后记录参考资源路径,并传给后续评估算子 + - `keyword_recall`:启用 `doKeywordRecall` 后写入中英文关键词召回率、样本数与报告路径,报告位于 `audio_reports/asr_pipeline/<文件名>/keyword_recall.txt` + +## 依赖说明 + +- **Python 依赖**(按启用功能而定): + - normalization/切分:`pydub`、`soundfile`、`numpy` + - LID:`torch`、`torchaudio`、`speechbrain` + - 降噪:`onnxruntime`(以及 GTCRN 模型文件) +- **系统依赖**: + - `pydub` 通常需要 `ffmpeg` +- **关键词召回率**: + - 使用纯 Python 文本处理,不额外依赖模型 + +## 版本历史 + +- **v1.0.0**:首次发布,支持音频标准化/(可选)降噪/过滤/LID/切分/ASR/合并 +- **v1.1.0**:同步 `audio_preprocessor` 关键词召回率能力,支持可选中英文关键词召回率评估 diff --git a/runtime/ops/mapper/audio_asr_pipeline/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/__init__.py new file mode 100644 index 000000000..9d54df284 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- + +from datamate.core.base_op import OPERATORS + +OPERATORS.register_module(module_name='AudioAsrPipeline', + module_path="ops.mapper.audio_asr_pipeline.process") diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/audio_config.yaml b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/audio_config.yaml new file mode 100644 index 000000000..ac4498e96 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/audio_config.yaml @@ -0,0 +1,8 @@ +audio_config: + # audio_config.yaml - 音频格式化配置 + output_format: "wav" + channels: 1 + sample_rate: 16000 + sample_width: 2 + encoding: "pcm_s16le" + input_format: ["mp3", "wav", "aac", "m4a", "flac"] \ No newline at end of file diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/eval_wer.yaml b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/eval_wer.yaml new file mode 100644 index 000000000..8d48be936 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/eval_wer.yaml @@ -0,0 +1,6 @@ +eval_wer: + zh_ref: "input_data/validation/zh_transcript.txt" + en_ref: "input_data/validation/en_transcript.txt" + hyp: "output_data/asr/merged_text.txt" + work_dir: "output_data/validation" + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/merge_asr_by_source.yaml b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/merge_asr_by_source.yaml new file mode 100644 index 000000000..17f2f5886 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/config/merge_asr_by_source.yaml @@ -0,0 +1,6 @@ +merge_asr_by_source: + list_file: "output_data/split/item_with_lang.list" + zh_text: "output_data/asr/zh/ctc_greedy_search/text" + en_text: "output_data/asr/en/ctc_greedy_search/text" + output: "output_data/asr/merged_text.txt" + diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/__init__.py new file mode 100644 index 000000000..483df895e --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/__init__.py @@ -0,0 +1,71 @@ +"""Comprehensive speech processing toolkit""" + +import os + +# For redirect of HF transformers +import speechbrain.lobes.models # noqa: F401 + +from .core import Brain, Stage, create_experiment_directory +from .utils.importutils import deprecated_redirect, lazy_export_all +from .utils.run_opts import RunOptions + +with open( + os.path.join(os.path.dirname(__file__), "version.txt"), encoding="utf-8" +) as f: + version = f.read().strip() + +# Create an alias to the refactored function +parse_arguments = RunOptions.from_command_line_args + +__all__ = [ + "Stage", + "Brain", + "create_experiment_directory", + "parse_arguments", +] + +__version__ = version + + +deprecations = { + "speechbrain.k2_integration": "speechbrain.integrations.k2_fsa", + "speechbrain.wordemb": "speechbrain.integrations.huggingface.wordemb", + "speechbrain.lobes.models.huggingface_transformers": "speechbrain.integrations.huggingface", + "speechbrain.lobes.models.spacy": "speechbrain.integrations.nlp", + "speechbrain.lobes.models.flair": "speechbrain.integrations.nlp", +} + + +def make_deprecated_redirections(): + sb1_0_redirect_str = ( + "This is a change from SpeechBrain 1.0. " + "See: https://github.com/speechbrain/speechbrain/releases/tag/v1.0.0" + ) + + deprecated_redirect( + "speechbrain.pretrained", + "speechbrain.inference", + extra_reason=sb1_0_redirect_str, + also_lazy_export=True, + ) + + for old_path, new_path in deprecations.items(): + deprecated_redirect(old_path, new_path, also_lazy_export=True) + + # speechbrain.nnet.loss is not yet loaded at this point, so we cannot use + # also_lazy_export (it would try to access sys.modules['speechbrain.nnet.loss']). + # The sys.modules redirect alone is sufficient for import compatibility. + deprecated_redirect( + "speechbrain.nnet.loss.transducer_loss", + "speechbrain.integrations.numba.transducer_loss", + extra_reason=( + "This module depends on the optional 'numba' package. " + "If you encounter an ImportError here, please install numba, " + "for example with: pip install numba" + ), + ) + + +make_deprecated_redirections() + +lazy_export_all(__file__, __name__, export_subpackages=True) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/__init__.py new file mode 100644 index 000000000..e44e4c84a --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/__init__.py @@ -0,0 +1 @@ +"""Tools for aligning transcripts and speech signals""" diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/aligner.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/aligner.py new file mode 100644 index 000000000..1287c507d --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/aligner.py @@ -0,0 +1,1494 @@ +""" +Alignment code + +Authors + * Elena Rastorgueva 2020 + * Loren Lugosch 2020 +""" + +import random + +import torch + +from speechbrain.utils.checkpoints import ( + mark_as_loader, + mark_as_saver, + register_checkpoint_hooks, +) +from speechbrain.utils.data_utils import undo_padding + + +@register_checkpoint_hooks +class HMMAligner(torch.nn.Module): + """This class calculates Viterbi alignments in the forward method. + + It also records alignments and creates batches of them for use + in Viterbi training. + + Arguments + --------- + states_per_phoneme : int + Number of hidden states to use per phoneme. + output_folder : str + It is the folder that the alignments will be stored in when + saved to disk. Not yet implemented. + neg_inf : float + The float used to represent a negative infinite log probability. + Using `-float("Inf")` tends to give numerical instability. + A number more negative than -1e5 also sometimes gave errors when + the `genbmm` library was used (currently not in use). (default: -1e5) + batch_reduction : string + One of "none", "sum" or "mean". + What kind of batch-level reduction to apply to the loss calculated + in the forward method. + input_len_norm : bool + Whether to normalize the loss in the forward method by the length of + the inputs. + target_len_norm : bool + Whether to normalize the loss in the forward method by the length of + the targets. + lexicon_path : string + The location of the lexicon. + + Example + ------- + >>> log_posteriors = torch.tensor( + ... [ + ... [ + ... [-1.0, -10.0, -10.0], + ... [-10.0, -1.0, -10.0], + ... [-10.0, -10.0, -1.0], + ... ], + ... [ + ... [-1.0, -10.0, -10.0], + ... [-10.0, -1.0, -10.0], + ... [-10.0, -10.0, -10.0], + ... ], + ... ] + ... ) + >>> lens = torch.tensor([1.0, 0.66]) + >>> phns = torch.tensor([[0, 1, 2], [0, 1, 0]]) + >>> phn_lens = torch.tensor([1.0, 0.66]) + >>> aligner = HMMAligner() + >>> forward_scores = aligner( + ... log_posteriors, lens, phns, phn_lens, "forward" + ... ) + >>> forward_scores.shape + torch.Size([2]) + >>> viterbi_scores, alignments = aligner( + ... log_posteriors, lens, phns, phn_lens, "viterbi" + ... ) + >>> alignments + [[0, 1, 2], [0, 1]] + >>> viterbi_scores.shape + torch.Size([2]) + """ + + def __init__( + self, + states_per_phoneme=1, + output_folder="", + neg_inf=-1e5, + batch_reduction="none", + input_len_norm=False, + target_len_norm=False, + lexicon_path=None, + ): + super().__init__() + self.states_per_phoneme = states_per_phoneme + self.output_folder = output_folder + self.neg_inf = neg_inf + + self.batch_reduction = batch_reduction + self.input_len_norm = input_len_norm + self.target_len_norm = target_len_norm + + self.align_dict = {} + self.lexicon_path = lexicon_path + + if self.lexicon_path is not None: + with open(self.lexicon_path, encoding="utf-8") as f: + lines = f.readlines() + + for i, line in enumerate(lines): + if line[0] != ";": + start_index = i + break + + lexicon = {} # {"read": {0: "r eh d", 1: "r iy d"}} + lexicon_phones = set() + for i in range(start_index, len(lines)): + line = lines[i] + word = line.split()[0] + phones = line.split("/")[1] + + phones = "".join([p for p in phones if not p.isdigit()]) + + for p in phones.split(" "): + lexicon_phones.add(p) + + if "~" in word: + word = word.split("~")[0] + if word in lexicon: + number_of_existing_pronunciations = len(lexicon[word]) + lexicon[word][number_of_existing_pronunciations] = phones + else: + lexicon[word] = {0: phones} + self.lexicon = lexicon + + lexicon_phones = list(lexicon_phones) + lexicon_phones.sort() + + self.lex_lab2ind = {p: i + 1 for i, p in enumerate(lexicon_phones)} + self.lex_ind2lab = {i + 1: p for i, p in enumerate(lexicon_phones)} + + # add sil, which is not in the lexicon + self.lex_lab2ind["sil"] = 0 + self.lex_ind2lab[0] = "sil" + + def _use_lexicon(self, words, interword_sils, sample_pron): + """Do processing using the lexicon to return a sequence of the possible + phonemes, the transition/pi probabilities, and the possible final states. + Inputs correspond to a single utterance, not a whole batch. + + Arguments + --------- + words : list + List of the words in the transcript. + interword_sils : bool + If True, optional silences will be inserted between every word. + If False, optional silences will only be placed at the beginning + and end of each utterance. + sample_pron : bool + If True, it will sample a single possible sequence of phonemes. + If False, it will return statistics for all possible sequences of + phonemes. + + Returns + ------- + poss_phns : torch.Tensor (phoneme) + The phonemes that are thought to be in each utterance. + log_transition_matrix : torch.Tensor (batch, from, to) + Tensor containing transition (log) probabilities. + start_states : list of ints + A list of the possible starting states in each utterance. + final_states : list of ints + A list of the possible final states for each utterance. + """ + + number_of_states = 0 + words_prime = [] # This will contain one "word" for each optional silence and pronunciation. + # structure of each "word_prime": + # [word index, [[state sequence 1], [state sequence 2]], ] + word_index = 0 + phoneme_indices = [] + for word in words: + if word_index == 0 or interword_sils is True: + # optional silence + word_prime = [ + word_index, + [ + [ + number_of_states + i + for i in range(self.states_per_phoneme) + ] + ], + True, + ] + words_prime.append(word_prime) + phoneme_indices += [ + self.silence_index * self.states_per_phoneme + i + for i in range(self.states_per_phoneme) + ] + number_of_states += self.states_per_phoneme + word_index += 1 + + # word + word_prime = [word_index, [], False] + if sample_pron and len(self.lexicon[word]) > 1: + random.shuffle(self.lexicon[word]) + for pron_idx in range(len(self.lexicon[word])): + pronunciation = self.lexicon[word][pron_idx] + phonemes = pronunciation.split() + word_prime[1].append([]) + for p in phonemes: + phoneme_indices += [ + self.lex_lab2ind[p] * self.states_per_phoneme + i + for i in range(self.states_per_phoneme) + ] + word_prime[1][pron_idx] += [ + number_of_states + i + for i in range(self.states_per_phoneme) + ] + number_of_states += self.states_per_phoneme + if sample_pron: + break + + words_prime.append(word_prime) + word_index += 1 + # optional final silence + word_prime = [ + word_index, + [[number_of_states + i for i in range(self.states_per_phoneme)]], + True, + ] + words_prime.append(word_prime) + phoneme_indices += [ + self.silence_index * self.states_per_phoneme + i + for i in range(self.states_per_phoneme) + ] + number_of_states += self.states_per_phoneme + word_index += 1 + + transition_matrix = 1.0 * torch.eye( + number_of_states + ) # diagonal = all states have a self-loop + final_states = [] + for word_prime in words_prime: + word_idx = word_prime[0] + is_optional_silence = word_prime[-1] + next_word_exists = word_idx < len(words_prime) - 2 + this_word_last_states = [ + word_prime[1][i][-1] for i in range(len(word_prime[1])) + ] + + # create transitions to next state from previous state within each pronunciation + for pronunciation in word_prime[1]: + for state_idx in range(len(pronunciation) - 1): + state = pronunciation[state_idx] + next_state = pronunciation[state_idx + 1] + transition_matrix[state, next_state] = 1.0 + + # create transitions to next word's starting states + if next_word_exists: + if is_optional_silence or not interword_sils: + next_word_idx = word_idx + 1 + else: + next_word_idx = word_idx + 2 + next_word_starting_states = [ + words_prime[next_word_idx][1][i][0] + for i in range(len(words_prime[next_word_idx][1])) + ] + + for this_word_last_state in this_word_last_states: + for next_word_starting_state in next_word_starting_states: + transition_matrix[ + this_word_last_state, next_word_starting_state + ] = 1.0 + + else: + final_states += this_word_last_states + + if not is_optional_silence: + next_silence_idx = word_idx + 1 + next_silence_starting_state = words_prime[next_silence_idx][1][ + 0 + ][0] + for this_word_last_state in this_word_last_states: + transition_matrix[ + this_word_last_state, next_silence_starting_state + ] = 1.0 + + log_transition_matrix = transition_matrix.log().log_softmax(1) + + start_states = [words_prime[0][1][0][0]] + start_states += [ + words_prime[1][1][i][0] for i in range(len(words_prime[1][1])) + ] + + poss_phns = torch.tensor(phoneme_indices) + + return poss_phns, log_transition_matrix, start_states, final_states + + def use_lexicon(self, words, interword_sils=True, sample_pron=False): + """Do processing using the lexicon to return a sequence of the possible + phonemes, the transition/pi probabilities, and the possible final + states. + Does processing on an utterance-by-utterance basis. Each utterance + in the batch is processed by a helper method `_use_lexicon`. + + Arguments + --------- + words : list + List of the words in the transcript + interword_sils : bool + If True, optional silences will be inserted between every word. + If False, optional silences will only be placed at the beginning + and end of each utterance. + sample_pron: bool + If True, it will sample a single possible sequence of phonemes. + If False, it will return statistics for all possible sequences of + phonemes. + + Returns + ------- + poss_phns: torch.Tensor (batch, phoneme in possible phn sequence) + The phonemes that are thought to be in each utterance. + poss_phn_lens: torch.Tensor (batch) + The relative length of each possible phoneme sequence in the batch. + trans_prob: torch.Tensor (batch, from, to) + Tensor containing transition (log) probabilities. + pi_prob: torch.Tensor (batch, state) + Tensor containing initial (log) probabilities. + final_state: list of lists of ints + A list of lists of possible final states for each utterance. + + Example + ------- + >>> aligner = HMMAligner() + >>> aligner.lexicon = {"a": {0: "a"}, "b": {0: "b", 1: "c"}} + >>> words = [["a", "b"]] + >>> aligner.lex_lab2ind = { + ... "sil": 0, + ... "a": 1, + ... "b": 2, + ... "c": 3, + ... } + >>> poss_phns, poss_phn_lens, trans_prob, pi_prob, final_states = ( + ... aligner.use_lexicon(words, interword_sils=True) + ... ) + >>> poss_phns + tensor([[0, 1, 0, 2, 3, 0]]) + >>> poss_phn_lens + tensor([1.]) + >>> trans_prob + tensor([[[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05, + -1.0000e+05], + [-1.0000e+05, -1.3863e+00, -1.3863e+00, -1.3863e+00, -1.3863e+00, + -1.0000e+05], + [-1.0000e+05, -1.0000e+05, -1.0986e+00, -1.0986e+00, -1.0986e+00, + -1.0000e+05], + [-1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01, -1.0000e+05, + -6.9315e-01], + [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01, + -6.9315e-01], + [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, + 0.0000e+00]]]) + >>> pi_prob + tensor([[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05, + -1.0000e+05]]) + >>> final_states + [[3, 4, 5]] + >>> # With no optional silences between words + >>> poss_phns_, _, trans_prob_, pi_prob_, final_states_ = ( + ... aligner.use_lexicon(words, interword_sils=False) + ... ) + >>> poss_phns_ + tensor([[0, 1, 2, 3, 0]]) + >>> trans_prob_ + tensor([[[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05], + [-1.0000e+05, -1.0986e+00, -1.0986e+00, -1.0986e+00, -1.0000e+05], + [-1.0000e+05, -1.0000e+05, -6.9315e-01, -1.0000e+05, -6.9315e-01], + [-1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01, -6.9315e-01], + [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, 0.0000e+00]]]) + >>> pi_prob_ + tensor([[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05]]) + >>> final_states_ + [[2, 3, 4]] + >>> # With sampling of a single possible pronunciation + >>> import random + >>> random.seed(0) + >>> poss_phns_, _, trans_prob_, pi_prob_, final_states_ = ( + ... aligner.use_lexicon(words, sample_pron=True) + ... ) + >>> poss_phns_ + tensor([[0, 1, 0, 2, 0]]) + >>> trans_prob_ + tensor([[[-6.9315e-01, -6.9315e-01, -1.0000e+05, -1.0000e+05, -1.0000e+05], + [-1.0000e+05, -1.0986e+00, -1.0986e+00, -1.0986e+00, -1.0000e+05], + [-1.0000e+05, -1.0000e+05, -6.9315e-01, -6.9315e-01, -1.0000e+05], + [-1.0000e+05, -1.0000e+05, -1.0000e+05, -6.9315e-01, -6.9315e-01], + [-1.0000e+05, -1.0000e+05, -1.0000e+05, -1.0000e+05, 0.0000e+00]]]) + """ + self.silence_index = self.lex_lab2ind["sil"] + + poss_phns = [] + trans_prob = [] + start_states = [] + final_states = [] + + for words_ in words: + ( + poss_phns_, + trans_prob_, + start_states_, + final_states_, + ) = self._use_lexicon(words_, interword_sils, sample_pron) + poss_phns.append(poss_phns_) + trans_prob.append(trans_prob_) + start_states.append(start_states_) + final_states.append(final_states_) + + # pad poss_phns, trans_prob with 0 to have same length + poss_phn_lens = [len(poss_phns_) for poss_phns_ in poss_phns] + U_max = max(poss_phn_lens) + + batch_size = len(poss_phns) + for index in range(batch_size): + phn_pad_length = U_max - len(poss_phns[index]) + poss_phns[index] = torch.nn.functional.pad( + poss_phns[index], (0, phn_pad_length), value=0 + ) + trans_prob[index] = torch.nn.functional.pad( + trans_prob[index], + (0, phn_pad_length, 0, phn_pad_length), + value=self.neg_inf, + ) + + # Stack into single tensor + poss_phns = torch.stack(poss_phns) + trans_prob = torch.stack(trans_prob) + trans_prob[trans_prob == -float("Inf")] = self.neg_inf + + # make pi prob + pi_prob = self.neg_inf * torch.ones([batch_size, U_max]) + for start_state in start_states: + pi_prob[:, start_state] = 1 + + pi_prob = torch.nn.functional.log_softmax(pi_prob, dim=1) + + # Convert poss_phn_lens from absolute to relative lengths + poss_phn_lens = torch.tensor(poss_phn_lens).float() / U_max + return poss_phns, poss_phn_lens, trans_prob, pi_prob, final_states + + def _make_pi_prob(self, phn_lens_abs): + """Creates tensor of initial (log) probabilities (known as 'pi'). + Assigns all probability mass to the first phoneme in the sequence. + + Arguments + --------- + phn_lens_abs : torch.Tensor (batch) + The absolute length of each phoneme sequence in the batch. + + Returns + ------- + pi_prob : torch.Tensor (batch, phn) + """ + batch_size = len(phn_lens_abs) + U_max = int(phn_lens_abs.max()) + + pi_prob = self.neg_inf * torch.ones([batch_size, U_max]) + pi_prob[:, 0] = 0 + + return pi_prob + + def _make_trans_prob(self, phn_lens_abs): + """Creates tensor of transition (log) probabilities. + Only allows transitions to the same phoneme (self-loop) or the next + phoneme in the phn sequence + + Arguments + --------- + phn_lens_abs : torch.Tensor (batch) + The absolute length of each phoneme sequence in the batch. + + Returns + ------- + trans_prob : torch.Tensor (batch, from, to) + """ + # Extract useful values for later + batch_size = len(phn_lens_abs) + U_max = int(phn_lens_abs.max()) + device = phn_lens_abs.device + + ## trans_prob matrix consists of 2 diagonals: + ## (1) offset diagonal (next state) & + ## (2) main diagonal (self-loop) + # make offset diagonal + trans_prob_off_diag = torch.eye(U_max - 1) + zero_side = torch.zeros([U_max - 1, 1]) + zero_bottom = torch.zeros([1, U_max]) + trans_prob_off_diag = torch.cat((zero_side, trans_prob_off_diag), 1) + trans_prob_off_diag = torch.cat((trans_prob_off_diag, zero_bottom), 0) + + # make main diagonal + trans_prob_main_diag = torch.eye(U_max) + + # join the diagonals and repeat for whole batch + trans_prob = trans_prob_off_diag + trans_prob_main_diag + trans_prob = ( + trans_prob.reshape(1, U_max, U_max) + .repeat(batch_size, 1, 1) + .to(device) + ) + + # clear probabilities for too-long sequences + mask_a = ( + torch.arange(U_max, device=device)[None, :] < phn_lens_abs[:, None] + ) + mask_a = mask_a.unsqueeze(2) + mask_a = mask_a.expand(-1, -1, U_max) + mask_b = mask_a.permute(0, 2, 1) + trans_prob = trans_prob * (mask_a & mask_b).float() + + ## put -infs in place of zeros: + trans_prob = torch.where( + trans_prob == 1, + trans_prob, + torch.tensor(-float("Inf"), device=device), + ) + + ## normalize + trans_prob = torch.nn.functional.log_softmax(trans_prob, dim=2) + + ## set nans to v neg numbers + trans_prob[trans_prob != trans_prob] = self.neg_inf + ## set -infs to v neg numbers + trans_prob[trans_prob == -float("Inf")] = self.neg_inf + + return trans_prob + + def _make_emiss_pred_useful( + self, emission_pred, lens_abs, phn_lens_abs, phns + ): + """Creates a 'useful' form of the posterior probabilities, rearranged + into the order of phoneme appearance in phns. + + Arguments + --------- + emission_pred : torch.Tensor (batch, time, phoneme in vocabulary) + posterior probabilities from our acoustic model + lens_abs : torch.Tensor (batch) + The absolute length of each input to the acoustic model, + i.e., the number of frames. + phn_lens_abs : torch.Tensor (batch) + The absolute length of each phoneme sequence in the batch. + phns : torch.Tensor (batch, phoneme in phn sequence) + The phonemes that are known/thought to be in each utterance. + + Returns + ------- + emiss_pred_useful : torch.Tensor + Tensor shape (batch, phoneme in phn sequence, time). + """ + # Extract useful values for later + U_max = int(phn_lens_abs.max().item()) + fb_max_length = int(lens_abs.max().item()) + device = emission_pred.device + + # apply mask based on lens_abs + mask_lens = ( + torch.arange(fb_max_length).to(device)[None, :] < lens_abs[:, None] + ) + + emiss_pred_acc_lens = torch.where( + mask_lens[:, :, None], + emission_pred, + torch.tensor([0.0], device=device), + ) + + # manipulate phn tensor, and then 'torch.gather' + phns = phns.to(device) + phns_copied = phns.unsqueeze(1).expand(-1, fb_max_length, -1) + emiss_pred_useful = torch.gather(emiss_pred_acc_lens, 2, phns_copied) + + # apply mask based on phn_lens_abs + mask_phn_lens = ( + torch.arange(U_max).to(device)[None, :] < phn_lens_abs[:, None] + ) + emiss_pred_useful = torch.where( + mask_phn_lens[:, None, :], + emiss_pred_useful, + torch.tensor([self.neg_inf], device=device), + ) + + emiss_pred_useful = emiss_pred_useful.permute(0, 2, 1) + + return emiss_pred_useful + + def _dp_forward( + self, + pi_prob, + trans_prob, + emiss_pred_useful, + lens_abs, + phn_lens_abs, + phns, + ): + """Does forward dynamic programming algorithm. + + Arguments + --------- + pi_prob : torch.Tensor (batch, phn) + Tensor containing initial (log) probabilities. + trans_prob : torch.Tensor (batch, from, to) + Tensor containing transition (log) probabilities. + emiss_pred_useful : torch.Tensor (batch, phoneme in phn sequence, time) + A 'useful' form of the posterior probabilities, rearranged + into the order of phoneme appearance in phns. + lens_abs : torch.Tensor (batch) + The absolute length of each input to the acoustic model, + i.e., the number of frames. + phn_lens_abs : torch.Tensor (batch) + The absolute length of each phoneme sequence in the batch. + phns : torch.Tensor (batch, phoneme in phn sequence) + The phonemes that are known/thought to be in each utterance. + + Returns + ------- + sum_alpha_T : torch.Tensor (batch) + The (log) likelihood of each utterance in the batch. + """ + # useful values + batch_size = len(phn_lens_abs) + U_max = phn_lens_abs.max() + fb_max_length = lens_abs.max() + device = emiss_pred_useful.device + + pi_prob = pi_prob.to(device) + trans_prob = trans_prob.to(device) + + # initialise + alpha_matrix = self.neg_inf * torch.ones( + [batch_size, U_max, fb_max_length], device=device + ) + alpha_matrix[:, :, 0] = pi_prob + emiss_pred_useful[:, :, 0] + + for t in range(1, fb_max_length): + utt_lens_passed = lens_abs < t + + if True in utt_lens_passed: + n_passed = utt_lens_passed.sum() + I_tensor = self.neg_inf * torch.ones(n_passed, U_max, U_max) + I_tensor[:, torch.arange(U_max), torch.arange(U_max)] = 0.0 + I_tensor = I_tensor.to(device) + + trans_prob[utt_lens_passed] = I_tensor + + alpha_times_trans = batch_log_matvecmul( + trans_prob.permute(0, 2, 1), alpha_matrix[:, :, t - 1] + ) + alpha_matrix[:, :, t] = ( + alpha_times_trans + emiss_pred_useful[:, :, t] + ) + + sum_alpha_T = torch.logsumexp( + alpha_matrix[torch.arange(batch_size), :, -1], dim=1 + ) + + return sum_alpha_T + + def _dp_viterbi( + self, + pi_prob, + trans_prob, + emiss_pred_useful, + lens_abs, + phn_lens_abs, + phns, + final_states, + ): + """Calculates Viterbi alignment using dynamic programming. + + Arguments + --------- + pi_prob : torch.Tensor (batch, phn) + Tensor containing initial (log) probabilities. + trans_prob : torch.Tensor (batch, from, to) + Tensor containing transition (log) probabilities. + emiss_pred_useful : torch.Tensor (batch, phoneme in phn sequence, time) + A 'useful' form of the posterior probabilities, rearranged + into the order of phoneme appearance in phns. + lens_abs : torch.Tensor (batch) + The absolute length of each input to the acoustic model, + i.e., the number of frames. + phn_lens_abs : torch.Tensor (batch) + The absolute length of each phoneme sequence in the batch. + phns : torch.Tensor (batch, phoneme in phn sequence) + The phonemes that are known/thought to be in each utterance. + final_states : list + List of final states + + Returns + ------- + z_stars : list of lists of int + Viterbi alignments for the files in the batch. + z_stars_loc : list of lists of int + The locations of the Viterbi alignments for the files in the batch. + e.g., for a batch with a single utterance with 5 phonemes, + `z_stars_loc` will look like: + [[0, 0, 0, 1, 1, 2, 3, 3, 3, 4, 4]]. + viterbi_scores : torch.Tensor (batch) + The (log) likelihood of the Viterbi path for each utterance. + """ + + # useful values + batch_size = len(phn_lens_abs) + U_max = phn_lens_abs.max() + fb_max_length = lens_abs.max() + device = emiss_pred_useful.device + + pi_prob = pi_prob.to(device) + trans_prob = trans_prob.to(device) + + v_matrix = self.neg_inf * torch.ones( + [batch_size, U_max, fb_max_length], device=device + ) + backpointers = -99 * torch.ones( + [batch_size, U_max, fb_max_length], device=device + ) + + # initialise + v_matrix[:, :, 0] = pi_prob + emiss_pred_useful[:, :, 0] + + for t in range(1, fb_max_length): + x, argmax = batch_log_maxvecmul( + trans_prob.permute(0, 2, 1), v_matrix[:, :, t - 1] + ) + v_matrix[:, :, t] = x + emiss_pred_useful[:, :, t] + + backpointers[:, :, t] = argmax.type(dtype=torch.float32) + + z_stars = [] + z_stars_loc = [] + + for utterance_in_batch in range(batch_size): + len_abs = lens_abs[utterance_in_batch] + + if final_states is not None: + final_states_utter = final_states[utterance_in_batch] + # Pick most probable of the final states + viterbi_finals = v_matrix[ + utterance_in_batch, final_states_utter, len_abs - 1 + ] + final_state_chosen = torch.argmax(viterbi_finals).item() + U = final_states_utter[final_state_chosen] + else: + U = phn_lens_abs[utterance_in_batch].long().item() - 1 + + z_star_i_loc = [U] + z_star_i = [phns[utterance_in_batch, z_star_i_loc[0]].item()] + for time_step in range(len_abs, 1, -1): + current_best_loc = z_star_i_loc[0] + + earlier_best_loc = ( + backpointers[ + utterance_in_batch, current_best_loc, time_step - 1 + ] + .long() + .item() + ) + earlier_z_star = phns[ + utterance_in_batch, earlier_best_loc + ].item() + + z_star_i_loc.insert(0, earlier_best_loc) + z_star_i.insert(0, earlier_z_star) + z_stars.append(z_star_i) + z_stars_loc.append(z_star_i_loc) + + # picking out viterbi_scores + viterbi_scores = v_matrix[ + torch.arange(batch_size), phn_lens_abs - 1, lens_abs - 1 + ] + + return z_stars, z_stars_loc, viterbi_scores + + def _loss_reduction(self, loss, input_lens, target_lens): + """Applies reduction to loss as specified during object initialization. + + Arguments + --------- + loss : torch.Tensor (batch) + The loss tensor to be reduced. + input_lens : torch.Tensor (batch) + The absolute durations of the inputs. + target_lens : torch.Tensor (batch) + The absolute durations of the targets. + + Returns + ------- + loss : torch.Tensor (batch, or scalar) + The loss with reduction applied if it is specified. + + """ + if self.input_len_norm is True: + loss = torch.div(loss, input_lens) + + if self.target_len_norm is True: + loss = torch.div(loss, target_lens) + + if self.batch_reduction == "none": + pass + elif self.batch_reduction == "sum": + loss = loss.sum() + elif self.batch_reduction == "mean": + loss = loss.mean() + else: + raise ValueError( + "`batch_reduction` parameter must be one of 'none', 'sum' or 'mean'" + ) + + return loss + + def forward( + self, + emission_pred, + lens, + phns, + phn_lens, + dp_algorithm, + prob_matrices=None, + ): + """Prepares relevant (log) probability tensors and does dynamic + programming: either the forward or the Viterbi algorithm. Applies + reduction as specified during object initialization. + + Arguments + --------- + emission_pred : torch.Tensor (batch, time, phoneme in vocabulary) + Posterior probabilities from our acoustic model. + lens : torch.Tensor (batch) + The relative duration of each utterance sound file. + phns : torch.Tensor (batch, phoneme in phn sequence) + The phonemes that are known/thought to be in each utterance + phn_lens : torch.Tensor (batch) + The relative length of each phoneme sequence in the batch. + dp_algorithm : string + Either "forward" or "viterbi". + prob_matrices : dict + (Optional) Must contain keys 'trans_prob', 'pi_prob' and 'final_states'. + Used to override the default forward and viterbi operations which + force traversal over all of the states in the `phns` sequence. + + Returns + ------- + tensor + + (1) if dp_algorithm == "forward". + + ``forward_scores`` : torch.Tensor (batch, or scalar) + + The (log) likelihood of each utterance in the batch, with reduction + applied if specified. (OR) + + (2) if dp_algorithm == "viterbi". + + ``viterbi_scores`` : torch.Tensor (batch, or scalar) + + The (log) likelihood of the Viterbi path for each utterance, with + reduction applied if specified. + + ``alignments`` : list of lists of int + + Viterbi alignments for the files in the batch. + """ + + lens_abs = torch.round(emission_pred.shape[1] * lens).long() + phn_lens_abs = torch.round(phns.shape[1] * phn_lens).long() + phns = phns.long() + + if prob_matrices is None: + pi_prob = self._make_pi_prob(phn_lens_abs) + trans_prob = self._make_trans_prob(phn_lens_abs) + final_states = None + else: + if ( + ("pi_prob" in prob_matrices) + and ("trans_prob" in prob_matrices) + and ("final_states" in prob_matrices) + ): + pi_prob = prob_matrices["pi_prob"] + trans_prob = prob_matrices["trans_prob"] + final_states = prob_matrices["final_states"] + else: + raise ValueError( + """`prob_matrices` must contain the keys + `pi_prob`, `trans_prob` and `final_states`""" + ) + + emiss_pred_useful = self._make_emiss_pred_useful( + emission_pred, lens_abs, phn_lens_abs, phns + ) + + if dp_algorithm == "forward": + # do forward training + forward_scores = self._dp_forward( + pi_prob, + trans_prob, + emiss_pred_useful, + lens_abs, + phn_lens_abs, + phns, + ) + + forward_scores = self._loss_reduction( + forward_scores, lens_abs, phn_lens_abs + ) + + return forward_scores + + elif dp_algorithm == "viterbi": + alignments, _, viterbi_scores = self._dp_viterbi( + pi_prob, + trans_prob, + emiss_pred_useful, + lens_abs, + phn_lens_abs, + phns, + final_states, + ) + + viterbi_scores = self._loss_reduction( + viterbi_scores, lens_abs, phn_lens_abs + ) + + return viterbi_scores, alignments + + else: + raise ValueError( + "dp_algorithm input must be either 'forward' or 'viterbi'" + ) + + def expand_phns_by_states_per_phoneme(self, phns, phn_lens): + """Expands each phoneme in the phn sequence by the number of hidden + states per phoneme defined in the HMM. + + Arguments + --------- + phns : torch.Tensor (batch, phoneme in phn sequence) + The phonemes that are known/thought to be in each utterance. + phn_lens : torch.Tensor (batch) + The relative length of each phoneme sequence in the batch. + + Returns + ------- + expanded_phns : torch.Tensor (batch, phoneme in expanded phn sequence) + + Example + ------- + >>> phns = torch.tensor([[0.0, 3.0, 5.0, 0.0], [0.0, 2.0, 0.0, 0.0]]) + >>> phn_lens = torch.tensor([1.0, 0.75]) + >>> aligner = HMMAligner(states_per_phoneme=3) + >>> expanded_phns = aligner.expand_phns_by_states_per_phoneme( + ... phns, phn_lens + ... ) + >>> expanded_phns + tensor([[ 0., 1., 2., 9., 10., 11., 15., 16., 17., 0., 1., 2.], + [ 0., 1., 2., 6., 7., 8., 0., 1., 2., 0., 0., 0.]]) + """ + # Initialise expanded_phns + expanded_phns = torch.zeros( + phns.shape[0], phns.shape[1] * self.states_per_phoneme + ) + expanded_phns = expanded_phns.to(phns.device) + + phns = undo_padding(phns, phn_lens) + for i, phns_utt in enumerate(phns): + expanded_phns_utt = [] + for phoneme_index in phns_utt: + expanded_phns_utt += [ + self.states_per_phoneme * phoneme_index + i_ + for i_ in range(self.states_per_phoneme) + ] + + expanded_phns[i, : len(expanded_phns_utt)] = torch.tensor( + expanded_phns_utt + ) + return expanded_phns + + def store_alignments(self, ids, alignments): + """Records Viterbi alignments in `self.align_dict`. + + Arguments + --------- + ids : list of str + IDs of the files in the batch. + alignments : list of lists of int + Viterbi alignments for the files in the batch. + Without padding. + + Example + ------- + >>> aligner = HMMAligner() + >>> ids = ["id1", "id2"] + >>> alignments = [[0, 2, 4], [1, 2, 3, 4]] + >>> aligner.store_alignments(ids, alignments) + >>> aligner.align_dict.keys() + dict_keys(['id1', 'id2']) + >>> aligner.align_dict["id1"] + tensor([0, 2, 4], dtype=torch.int16) + """ + + for i, id in enumerate(ids): + alignment_i = alignments[i] + alignment_i = torch.tensor(alignment_i, dtype=torch.int16).cpu() + self.align_dict[id] = alignment_i + + def _get_flat_start_batch(self, lens_abs, phn_lens_abs, phns): + """Prepares flat start alignments (with zero padding) for every utterance + in the batch. + Every phoneme will have an equal duration, except for the final phoneme + potentially. E.g. if 104 frames and 10 phonemes, 9 phonemes will have + duration of 10 frames, and one phoneme will have a duration of 14 frames. + + Arguments + --------- + lens_abs : torch.Tensor (batch) + The absolute length of each input to the acoustic model, + i.e., the number of frames. + + phn_lens_abs : torch.Tensor (batch) + The absolute length of each phoneme sequence in the batch. + + phns : torch.Tensor (batch, phoneme in phn sequence) + The phonemes that are known/thought to be in each utterance. + + Returns + ------- + flat_start_batch : torch.Tensor (batch, time) + Flat start alignments for utterances in the batch, with zero padding. + """ + phns = phns.long() + + batch_size = len(lens_abs) + fb_max_length = torch.max(lens_abs) + + flat_start_batch = torch.zeros( + batch_size, fb_max_length, device=phns.device + ).long() + for i in range(batch_size): + utter_phns = phns[i] + utter_phns = utter_phns[: phn_lens_abs[i]] # crop out zero padding + repeat_amt = int(lens_abs[i].item() / len(utter_phns)) + + # make sure repeat_amt is at least 1. (the code above + # may make repeat_amt==0 if self.states_per_phoneme is too large). + if repeat_amt == 0: + repeat_amt = 1 + + # repeat each phoneme in utter_phns by repeat_amt + utter_phns = utter_phns.repeat_interleave(repeat_amt) + + # len(utter_phns) may be <, == or > lens_abs[i], so + # make sure len(utter_phns) == lens_abs[i] + utter_phns = utter_phns[: lens_abs[i]] + utter_phns = torch.nn.functional.pad( + utter_phns, + (0, int(lens_abs[i]) - len(utter_phns)), + value=utter_phns[-1], # pad out with final phoneme + ) + + flat_start_batch[i, : len(utter_phns)] = utter_phns + + return flat_start_batch + + def _get_viterbi_batch(self, ids, lens_abs): + """Retrieves Viterbi alignments stored in `self.align_dict` and + creates a batch of them, with zero padding. + + Arguments + --------- + ids : list of str + IDs of the files in the batch. + lens_abs : torch.Tensor (batch) + The absolute length of each input to the acoustic model, + i.e., the number of frames. + + Returns + ------- + viterbi_batch : torch.Tensor (batch, time) + The previously-recorded Viterbi alignments for the utterances + in the batch. + + """ + batch_size = len(lens_abs) + fb_max_length = torch.max(lens_abs) + + viterbi_batch = torch.zeros( + batch_size, fb_max_length, device=lens_abs.device + ).long() + for i in range(batch_size): + viterbi_preds = self.align_dict[ids[i]] + viterbi_preds = torch.nn.functional.pad( + viterbi_preds, (0, fb_max_length - len(viterbi_preds)) + ) + + viterbi_batch[i] = viterbi_preds.long() + + return viterbi_batch + + def get_prev_alignments(self, ids, emission_pred, lens, phns, phn_lens): + """Fetches previously recorded Viterbi alignments if they are available. + If not, fetches flat start alignments. + Currently, assumes that if a Viterbi alignment is not available for the + first utterance in the batch, it will not be available for the rest of + the utterances. + + Arguments + --------- + ids : list of str + IDs of the files in the batch. + emission_pred : torch.Tensor (batch, time, phoneme in vocabulary) + Posterior probabilities from our acoustic model. Used to infer the + duration of the longest utterance in the batch. + lens : torch.Tensor (batch) + The relative duration of each utterance sound file. + phns : torch.Tensor (batch, phoneme in phn sequence) + The phonemes that are known/thought to be in each utterance. + phn_lens : torch.Tensor (batch) + The relative length of each phoneme sequence in the batch. + + Returns + ------- + torch.Tensor (batch, time) + Zero-padded alignments. + + Example + ------- + >>> ids = ["id1", "id2"] + >>> emission_pred = torch.tensor( + ... [ + ... [ + ... [-1.0, -10.0, -10.0], + ... [-10.0, -1.0, -10.0], + ... [-10.0, -10.0, -1.0], + ... ], + ... [ + ... [-1.0, -10.0, -10.0], + ... [-10.0, -1.0, -10.0], + ... [-10.0, -10.0, -10.0], + ... ], + ... ] + ... ) + >>> lens = torch.tensor([1.0, 0.66]) + >>> phns = torch.tensor([[0, 1, 2], [0, 1, 0]]) + >>> phn_lens = torch.tensor([1.0, 0.66]) + >>> aligner = HMMAligner() + >>> alignment_batch = aligner.get_prev_alignments( + ... ids, emission_pred, lens, phns, phn_lens + ... ) + >>> alignment_batch + tensor([[0, 1, 2], + [0, 1, 0]]) + """ + + lens_abs = torch.round(emission_pred.shape[1] * lens).long() + phn_lens_abs = torch.round(phns.shape[1] * phn_lens).long() + + if ids[0] in self.align_dict: + return self._get_viterbi_batch(ids, lens_abs) + else: + return self._get_flat_start_batch(lens_abs, phn_lens_abs, phns) + + def _calc_accuracy_sent(self, alignments_, ends_, phns_): + """Calculates the accuracy between predicted alignments and ground truth + alignments for a single sentence/utterance. + + Arguments + --------- + alignments_ : list of ints + The predicted alignments for the utterance. + ends_ : list of ints + A list of the sample indices where each ground truth phoneme + ends, according to the transcription. + phns_ : list of ints + The unpadded list of ground truth phonemes in the utterance. + + Returns + ------- + mean_acc : float + The mean percentage of times that the upsampled predicted alignment + matches the ground truth alignment. + """ + # Create array containing the true alignment at each sample + ends_ = [0] + [int(end) for end in ends_] + true_durations = [ends_[i] - ends_[i - 1] for i in range(1, len(ends_))] + true_alignments = [] + + for i in range(len(phns_)): + true_alignments += [phns_[i]] * (true_durations[i]) + true_alignments = torch.tensor(true_alignments) + + # Upsample the predicted alignment array + # and make sure length matches that of `true_alignment` + upsample_factor = int( + torch.round(torch.tensor(len(true_alignments) / len(alignments_))) + ) + + alignments_ = torch.tensor(alignments_) + alignments_upsampled = alignments_.repeat_interleave(upsample_factor) + alignments_upsampled = alignments_upsampled[: len(true_alignments)] + + if len(true_alignments) > len(alignments_upsampled): + alignments_upsampled = torch.nn.functional.pad( + alignments_upsampled, + (0, len(true_alignments) - len(alignments_upsampled)), + ) + + # Measure sample-wise accuracy + accuracy = ( + alignments_upsampled == true_alignments + ).float().mean().item() * 100 + + return accuracy + + def calc_accuracy(self, alignments, ends, phns, ind2labs=None): + """Calculates mean accuracy between predicted alignments and ground truth + alignments. Ground truth alignments are derived from ground truth phns + and their ends in the audio sample. + + Arguments + --------- + alignments : list of lists of ints/floats + The predicted alignments for each utterance in the batch. + ends : list of lists of ints + A list of lists of sample indices where each ground truth phoneme + ends, according to the transcription. + Note: current implementation assumes that 'ends' mark the index + where the next phoneme begins. + phns : list of lists of ints/floats + The unpadded list of lists of ground truth phonemes in the batch. + ind2labs : tuple + (Optional) + Contains the original index-to-label dicts for the first and second + sequence of phonemes. + + Returns + ------- + mean_acc : float + The mean percentage of times that the upsampled predicted alignment + matches the ground truth alignment. + + Example + ------- + >>> aligner = HMMAligner() + >>> alignments = [[0.0, 0.0, 0.0, 1.0]] + >>> phns = [[0.0, 1.0]] + >>> ends = [[2, 4]] + >>> mean_acc = aligner.calc_accuracy(alignments, ends, phns) + >>> mean_acc.item() + 75.0 + """ + acc_hist = [] + + # Do conversion if states_per_phoneme > 1 + if self.states_per_phoneme > 1: + alignments = [ + [i // self.states_per_phoneme for i in utt] + for utt in alignments + ] + + # convert to common alphabet if need be + if ind2labs is not None: + alignments, phns = map_inds_to_intersect(alignments, phns, ind2labs) + + for alignments_, ends_, phns_ in zip(alignments, ends, phns): + acc = self._calc_accuracy_sent(alignments_, ends_, phns_) + acc_hist.append(acc) + + acc_hist = torch.tensor(acc_hist) + mean_acc = acc_hist.mean() + + return mean_acc.unsqueeze(0) + + def collapse_alignments(self, alignments): + """ + Converts alignments to 1 state per phoneme style. + + Arguments + --------- + alignments : list of ints + Predicted alignments for a single utterance. + + Returns + ------- + sequence : list of ints + The predicted alignments converted to a 1 state per phoneme style. + + Example + ------- + >>> aligner = HMMAligner(states_per_phoneme=3) + >>> alignments = [0, 1, 2, 3, 4, 5, 3, 4, 5, 0, 1, 2] + >>> sequence = aligner.collapse_alignments(alignments) + >>> sequence + [0, 1, 1, 0] + """ + + # Filter the repetitions + sequence = [ + v + for i, v in enumerate(alignments) + if i == 0 or v != alignments[i - 1] + ] + + # Pick out only multiples of self.states_per_phoneme + sequence = [v for v in sequence if v % self.states_per_phoneme == 0] + + # Divide by self.states_per_phoneme + sequence = [v // self.states_per_phoneme for v in sequence] + + return sequence + + @mark_as_saver + def _save(self, path): + torch.save(self.align_dict, path) + + @mark_as_loader + def _load(self, path, end_of_epoch=False): + del end_of_epoch # Not used here. + self.align_dict = torch.load(path) + + +def map_inds_to_intersect(lists1, lists2, ind2labs): + """Converts 2 lists containing indices for phonemes from different + phoneme sets to a single phoneme so that comparing the equality + of the indices of the resulting lists will yield the correct + accuracy. + + Arguments + --------- + lists1 : list of lists of ints + Contains the indices of the first sequence of phonemes. + lists2 : list of lists of ints + Contains the indices of the second sequence of phonemes. + ind2labs : tuple (dict, dict) + Contains the original index-to-label dicts for the first and second + sequence of phonemes. + + Returns + ------- + lists1_new : list of lists of ints + Contains the indices of the first sequence of phonemes, mapped + to the new phoneme set. + lists2_new : list of lists of ints + Contains the indices of the second sequence of phonemes, mapped + to the new phoneme set. + + Example + ------- + >>> lists1 = [[0, 1]] + >>> lists2 = [[0, 1]] + >>> ind2lab1 = { + ... 0: "a", + ... 1: "b", + ... } + >>> ind2lab2 = { + ... 0: "a", + ... 1: "c", + ... } + >>> ind2labs = (ind2lab1, ind2lab2) + >>> out1, out2 = map_inds_to_intersect(lists1, lists2, ind2labs) + >>> out1 + [[0, 1]] + >>> out2 + [[0, 2]] + """ + ind2lab1, ind2lab2 = ind2labs + + # Form 3 sets: + # (1) labs in both mappings + # (2) labs in only 1st mapping + # (3) labs in only 2nd mapping + set1, set2 = set(ind2lab1.values()), set(ind2lab2.values()) + + intersect = set1.intersection(set2) + set1_only = set1.difference(set2) + set2_only = set2.difference(set1) + + new_lab2ind = {lab: i for i, lab in enumerate(intersect)} + new_lab2ind.update( + {lab: len(new_lab2ind) + i for i, lab in enumerate(set1_only)} + ) + new_lab2ind.update( + {lab: len(new_lab2ind) + i for i, lab in enumerate(set2_only)} + ) + + # Map lists to labels and apply new_lab2ind + lists1_lab = [[ind2lab1[ind] for ind in utt] for utt in lists1] + lists2_lab = [[ind2lab2[ind] for ind in utt] for utt in lists2] + + lists1_new = [[new_lab2ind[lab] for lab in utt] for utt in lists1_lab] + lists2_new = [[new_lab2ind[lab] for lab in utt] for utt in lists2_lab] + + return lists1_new, lists2_new + + +def batch_log_matvecmul(A, b): + """For each 'matrix' and 'vector' pair in the batch, do matrix-vector + multiplication in the log domain, i.e., logsumexp instead of add, + add instead of multiply. + + Arguments + --------- + A : torch.Tensor (batch, dim1, dim2) + Tensor + b : torch.Tensor (batch, dim1) + Tensor. + + Returns + ------- + x : torch.Tensor (batch, dim1) + + Example + ------- + >>> A = torch.tensor([[[0.0, 0.0], [-1e5, 0.0]]]) + >>> b = torch.tensor( + ... [ + ... [ + ... 0.0, + ... 0.0, + ... ] + ... ] + ... ) + >>> x = batch_log_matvecmul(A, b) + >>> x + tensor([[0.6931, 0.0000]]) + >>> + >>> # non-log domain equivalent without batching functionality + >>> A_ = torch.tensor([[1.0, 1.0], [0.0, 1.0]]) + >>> b_ = torch.tensor( + ... [ + ... 1.0, + ... 1.0, + ... ] + ... ) + >>> x_ = torch.matmul(A_, b_) + >>> x_ + tensor([2., 1.]) + """ + b = b.unsqueeze(1) + x = torch.logsumexp(A + b, dim=2) + + return x + + +def batch_log_maxvecmul(A, b): + """Similar to batch_log_matvecmul, but takes a maximum instead of + logsumexp. Returns both the max and the argmax. + + Arguments + --------- + A : torch.Tensor (batch, dim1, dim2) + Tensor. + b : torch.Tensor (batch, dim1) + Tensor + + Returns + ------- + x : torch.Tensor (batch, dim1) + Tensor. + argmax : torch.Tensor (batch, dim1) + Tensor. + + Example + ------- + >>> A = torch.tensor([[[0.0, -1.0], [-1e5, 0.0]]]) + >>> b = torch.tensor( + ... [ + ... [ + ... 0.0, + ... 0.0, + ... ] + ... ] + ... ) + >>> x, argmax = batch_log_maxvecmul(A, b) + >>> x + tensor([[0., 0.]]) + >>> argmax + tensor([[0, 1]]) + """ + b = b.unsqueeze(1) + x, argmax = torch.max(A + b, dim=2) + + return x, argmax diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/ctc_segmentation.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/ctc_segmentation.py new file mode 100644 index 000000000..72888467e --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/alignment/ctc_segmentation.py @@ -0,0 +1,11 @@ +"""This file ensures old links to speechtokenizer continue to work while providing a Deprecation warning""" + +import warnings + +from speechbrain.integrations.alignment.ctc_seg import * # noqa: F401, F403 + +warnings.warn( + message="speechbrain.alignment.ctc_segmentation has moved to speechbrain.integrations.alignment.ctc_seg", + category=DeprecationWarning, + stacklevel=2, +) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/__init__.py new file mode 100644 index 000000000..81893fb79 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/__init__.py @@ -0,0 +1 @@ +"""Package containing various techniques of data augmentation""" diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/augmenter.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/augmenter.py new file mode 100644 index 000000000..37b79a73f --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/augmenter.py @@ -0,0 +1,544 @@ +"""Classes for implementing data augmentation pipelines. + +Authors + * Mirco Ravanelli 2022 +""" + +import random + +import torch +import torch.nn.functional as F + +from speechbrain.utils.callchains import lengths_arg_exists +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +class Augmenter(torch.nn.Module): + """Applies pipelines of data augmentation. + + Arguments + --------- + parallel_augment: bool + If False, the augmentations are applied sequentially with + the order specified in the pipeline argument. + When True, all the N augmentations are concatenated in the output + on the batch axis. + parallel_augment_fixed_bs: bool + If False, each augmenter (performed in parallel) generates a number of + augmented examples equal to the batch size. Thus, overall, with this + option N*batch size artificial data are + generated, where N is the number of augmenters. + When True, the number of total augmented examples is kept fixed at + the batch size, thus, for each augmenter, fixed at batch size // N examples. + This option is useful to keep controlled the number of synthetic examples + with respect to the original data distribution, as it keep always + 50% of original data, and 50% of augmented data. + concat_original: bool + if True, the original input is concatenated with the + augmented outputs (on the batch axis). + min_augmentations: int + The number of augmentations applied to the input signal is randomly + sampled between min_augmentations and max_augmentations. For instance, + if the augmentation dict contains N=6 augmentations and we set + select min_augmentations=1 and max_augmentations=4 we apply up to + M=4 augmentations. The selected augmentations are applied in the order + specified in the augmentations dict. If shuffle_augmentations = True, + a random set of M augmentations is selected. + max_augmentations: int + Maximum number of augmentations to apply. See min_augmentations for + more details. + shuffle_augmentations: bool + If True, it shuffles the entries of the augmentations dictionary. + The effect is to randomply select the order of the augmentations + to apply. + repeat_augment: int + Applies the augmentation algorithm N times. This can be used to + perform more data augmentation. + augment_start_index: int + The index of the first element in the input batch from which data + augmentation should begin. + This argument allows you to specify the starting point for applying + data augmentation. + augment_end_index: int + The index of the last element in the input batch at which data + augmentation should stop. + You can use this argument to define the endpoint for applying data + augmentation within the batch. + concat_start_index: int + If `concat_original` is set to True, you can specify a subpart of the + original batch to concatenate in the output. + Use this argument to select the index of the first element from the + original input batch to start copying from. + concat_end_index: int + If `concat_original` is set to True, you can specify a subpart of the + original batch to concatenate in the output. Use this argument to select + the index of the last element from the original input batch to end the + copying process. + augment_prob: float + The probability (0.0 to 1.0) of applying data augmentation. When set to 0.0, + the original signal is returned without any augmentation. When set to 1.0, + augmentation is always applied. Values in between determine the likelihood + of augmentation. + augmentations: list + List of augmentater objects to combine to perform data augmentation. + enable_augmentations: list + A list of booleans used to selectively enable or disable specific augmentation + techniques within the 'augmentations' list. + Each boolean corresponds to an augmentation object in the 'augmentations' list + and should be of the same length and order. + This feature is useful for performing ablations on augmentation techniques to + tailor them for a specific task. + + Example + ------- + >>> from speechbrain.augment.time_domain import DropFreq, DropChunk + >>> freq_dropper = DropFreq() + >>> chunk_dropper = DropChunk(drop_start=100, drop_end=16000) + >>> augment = Augmenter( + ... parallel_augment=False, + ... concat_original=False, + ... augmentations=[freq_dropper, chunk_dropper], + ... ) + >>> signal = torch.rand([4, 16000]) + >>> output_signal, lengths = augment( + ... signal, lengths=torch.tensor([0.2, 0.5, 0.7, 1.0]) + ... ) + """ + + def __init__( + self, + parallel_augment=False, + parallel_augment_fixed_bs=False, + concat_original=False, + min_augmentations=None, + max_augmentations=None, + shuffle_augmentations=False, + repeat_augment=1, + augment_start_index=0, + augment_end_index=None, + concat_start_index=0, + concat_end_index=None, + augment_prob=1.0, + augmentations=list(), + enable_augmentations=None, + ): + super().__init__() + self.parallel_augment = parallel_augment + self.parallel_augment_fixed_bs = parallel_augment_fixed_bs + self.concat_original = concat_original + self.augmentations = augmentations + self.min_augmentations = min_augmentations + self.max_augmentations = max_augmentations + self.shuffle_augmentations = shuffle_augmentations + self.augment_start_index = augment_start_index + self.augment_end_index = augment_end_index + self.concat_start_index = concat_start_index + self.concat_end_index = concat_end_index + self.repeat_augment = repeat_augment + self.augment_prob = augment_prob + # Check min and max augmentations + self.check_min_max_augmentations() + + # This variable represents the total number of augmentations to perform for each signal, + # including the original signal in the count. + self.num_augmentations = None + self.do_augment = True + + # Check repeat augment arguments + if not isinstance(self.repeat_augment, int): + raise ValueError("repeat_augment must be an integer.") + + if self.repeat_augment < 0: + raise ValueError("repeat_augment must be greater than 0.") + + if self.augment_end_index is not None: + if self.augment_end_index < self.augment_start_index: + raise ValueError( + "augment_end_index must be smaller or equal to augment_start_index." + ) + + if self.concat_end_index is not None: + if self.concat_end_index < self.concat_start_index: + raise ValueError( + "concat_end_index must be smaller or equal to concat_start_index." + ) + + # Managing enable augmentations + if enable_augmentations is None: + enable_augmentations = [True] * len(augmentations) + elif not isinstance(enable_augmentations, list): + raise ValueError("enable_augmentations must be a list.") + elif len(enable_augmentations) != len(augmentations): + raise ValueError( + "enable_augmentations must have the same length as augmentations." + ) + else: + augmentations = [ + aug + for aug, enabled in zip(augmentations, enable_augmentations) + if enabled + ] + + # Turn augmentations into a dictionary + self.augmentations = { + augmentation.__class__.__name__ + str(i): augmentation + for i, augmentation in enumerate(augmentations) + } + + if len(self.augmentations) == 0: + logger.warning( + "No augmentation is applied because the augmentation list is empty." + ) + + # Check min and max augmentations + if self.max_augmentations <= 0: + logger.warning( + "No augmentations applied because max_augmentations is non-positive." + ) + if self.min_augmentations < 0: + self.min_augmentations = 0 + logger.warning( + "min_augmentations is negative. Modified to be non-negative." + ) + if self.min_augmentations > self.max_augmentations: + logger.warning( + "min_augmentations is greater than max_augmentations. min_augmentations set to max_augmentations." + ) + self.max_augmentations = self.min_augmentations + + # Check if augmentation modules need the length argument + self.require_lengths = {} + for aug_key, aug_fun in self.augmentations.items(): + self.require_lengths[aug_key] = lengths_arg_exists(aug_fun.forward) + + def augment(self, x, lengths, selected_augmentations): + """Applies data augmentation on the selected augmentations. + + Arguments + --------- + x : torch.Tensor (batch, time, channel) + input to augment. + lengths : torch.Tensor + The length of each sequence in the batch. + selected_augmentations: dict + Dictionary containing the selected augmentation to apply. + + Returns + ------- + output : torch.Tensor + Augmented outputs. + output_lengths : torch.Tensor + The corresponding length of each output. + """ + next_input = x + next_lengths = lengths + output = [] + output_lengths = [] + out_lengths = lengths + for k, augment_name in enumerate(selected_augmentations): + augment_fun = self.augmentations[augment_name] + + idx = torch.arange(x.shape[0]) + if self.parallel_augment and self.parallel_augment_fixed_bs: + idx_startstop = torch.linspace( + 0, x.shape[0], len(selected_augmentations) + 1 + ).to(torch.int) + idx_start = idx_startstop[k] + idx_stop = idx_startstop[k + 1] + idx = idx[idx_start:idx_stop] + + # Check input arguments + if self.require_lengths[augment_name]: + out = augment_fun( + next_input[idx, ...], lengths=next_lengths[idx] + ) + else: + out = augment_fun(next_input[idx, ...]) + + # Check output arguments + if isinstance(out, tuple): + if len(out) == 2: + out, out_lengths = out + else: + raise ValueError( + "The function must return max two arguments (Tensor, Length[optional])" + ) + + # Manage sequential or parallel augmentation + if not self.parallel_augment: + next_input = out + next_lengths = out_lengths[idx] + else: + output.append(out) + output_lengths.append(out_lengths) + + if self.parallel_augment: + # Concatenate all the augmented data + output, output_lengths = self.concatenate_outputs( + output, output_lengths + ) + else: + # Take the last augmented signal of the pipeline + output = out + output_lengths = out_lengths + + return output, output_lengths + + def forward(self, x, lengths): + """Applies data augmentation. + + Arguments + --------- + x : torch.Tensor (batch, time, channel) + input to augment. + lengths : torch.Tensor + The length of each sequence in the batch. + + Returns + ------- + output : torch.Tensor + Augmented outputs. + output_lengths : torch.Tensor + The corresponding length of each output. + """ + + # Determine whether to apply data augmentation + self.do_augment = True + if random.random() > self.augment_prob: + self.do_augment = False + return x, lengths + + x_original = x + len_original = lengths + + # Determine the ending index for augmentation, considering user-specified or default values. + self.augment_end_index_batch = ( + min(self.augment_end_index, x.shape[0]) + if self.augment_end_index is not None + else x.shape[0] + ) + + # If the augmentation starting index is beyond the size of the data, return the original data. + if self.augment_start_index >= x.shape[0]: + self.do_augment = False + logger.warning( + "No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch." + ) + return x, lengths + + # Select the number of augmentations to apply + self.N_augment = torch.randint( + low=self.min_augmentations, + high=self.max_augmentations + 1, + size=(1,), + device=x.device, + ) + + # Get augmentations list + augmentations_lst = list(self.augmentations.keys()) + + # No augmentation + if ( + self.repeat_augment == 0 + or self.N_augment == 0 + or len(augmentations_lst) == 0 + ): + self.do_augment = False + return x, lengths + + # Shuffle augmentation + if self.shuffle_augmentations: + random.shuffle(augmentations_lst) + + # Select the augmentations to apply + selected_augmentations = augmentations_lst[0 : self.N_augment] + + # Select the portion of the input to augment and update lengths accordingly. + x = x[self.augment_start_index : self.augment_end_index_batch] + lengths = lengths[ + self.augment_start_index : self.augment_end_index_batch + ] + + # Lists to collect the outputs + output_lst = [] + output_len_lst = [] + + # Concatenate the original signal if required + self.skip_concat = not (self.concat_original) + if self.concat_original: + # Check start index + if self.concat_start_index >= x_original.shape[0]: + self.skip_concat = True + pass + else: + self.skip_concat = False + # Determine the ending index for concatenation, considering user-specified or default values. + self.concat_end_index_batch = ( + min(self.concat_end_index, x_original.shape[0]) + if self.concat_end_index is not None + else x_original.shape[0] + ) + + output_lst.append( + x_original[ + self.concat_start_index : self.concat_end_index_batch + ] + ) + output_len_lst.append( + len_original[ + self.concat_start_index : self.concat_end_index_batch + ] + ) + + # Perform augmentations + for i in range(self.repeat_augment): + output, output_lengths = self.augment( + x, lengths, selected_augmentations + ) + output_lst.append(output) + output_len_lst.append(output_lengths) + + # Concatenate the final outputs while handling scenarios where + # different temporal dimensions may arise due to augmentations + # like speed change. + output, output_lengths = self.concatenate_outputs( + output_lst, output_len_lst + ) + + return output, output_lengths + + def concatenate_outputs(self, augment_lst, augment_len_lst): + """ + Concatenate a list of augmented signals, accounting for varying temporal lengths. + Padding is applied to ensure all signals can be concatenated. + + Arguments + --------- + augment_lst : List of torch.Tensor + List of augmented signals to be concatenated. + augment_len_lst : List of torch.Tensor + List of lengths corresponding to the augmented signals. + + Returns + ------- + concatenated_signals : torch.Tensor + A tensor containing the concatenated signals. + concatenated_lengths : torch.Tensor + A tensor containing the concatenated signal lengths. + + Notes + ----- + This function takes a list of augmented signals, which may have different temporal + lengths due to variations such as speed changes. It pads the signals to match the + maximum temporal dimension found among the input signals and rescales the lengths + accordingly before concatenating them. + """ + + # Find the maximum temporal dimension (batch length) among the sequences + max_len = max(augment.shape[1] for augment in augment_lst) + + # Rescale the sequence lengths to adjust for augmented batches with different temporal dimensions. + augment_len_lst = [ + length * (output.shape[1] / max_len) + for length, output in zip(augment_len_lst, augment_lst) + ] + + # Pad sequences to match the maximum temporal dimension. + # Note that some augmented batches, like those with speed changes, may have different temporal dimensions. + augment_lst = [ + F.pad(output, (0, max_len - output.shape[1])) + for output in augment_lst + ] + + # Concatenate the padded sequences and rescaled lengths + output = torch.cat(augment_lst, dim=0) + output_lengths = torch.cat(augment_len_lst, dim=0) + + return output, output_lengths + + def replicate_multiple_labels(self, *args): + """ + Replicates the labels along the batch axis a number of times that + corresponds to the number of augmentations. Indeed parallel and + concatenation augmentations alter the time dimension. + + Arguments + --------- + *args : tuple + Input label tensors to be replicated. Can be a uniq or a list of + torch.Tensors. + + Returns + ------- + augmented_labels: torch.Tensor + Labels corresponding to the augmented input. Returns as many torch.Tensor + as given in input. + """ + + # Determine whether to apply data augmentation + if not self.do_augment: + return args + + list_of_augmented_labels = [] + + for labels in args: + list_of_augmented_labels.append(self.replicate_labels(labels)) + + return list_of_augmented_labels + + def replicate_labels(self, labels): + """ + Replicates the labels along the batch axis a number of times that + corresponds to the number of augmentations. Indeed parallel and + concatenation augmentations alter the time dimension. + + Arguments + --------- + labels : torch.Tensor + Input label tensors to be replicated. + + Returns + ------- + augmented_labels: torch.Tensor + Labels corresponding to the augmented input. Returns as many torch.Tensor + as given in input. + """ + + # Determine whether to apply data augmentation + if not self.do_augment: + return labels + + augmented_labels = [] + if self.concat_original and not (self.skip_concat): + augmented_labels = [ + labels[self.concat_start_index : self.concat_end_index_batch] + ] + selected_labels = labels[ + self.augment_start_index : self.augment_end_index_batch + ] + + if self.parallel_augment: + selected_labels = torch.cat( + [selected_labels] * self.N_augment, dim=0 + ) + + augmented_labels = ( + augmented_labels + [selected_labels] * self.repeat_augment + ) + + augmented_labels = torch.cat(augmented_labels, dim=0) + + return augmented_labels + + def check_min_max_augmentations(self): + """Checks the min_augmentations and max_augmentations arguments.""" + if self.min_augmentations is None: + self.min_augmentations = 1 + if self.max_augmentations is None: + self.max_augmentations = len(self.augmentations) + if self.max_augmentations > len(self.augmentations): + self.max_augmentations = len(self.augmentations) + if self.min_augmentations > len(self.augmentations): + self.min_augmentations = len(self.augmentations) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/codec.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/codec.py new file mode 100644 index 000000000..50c2953cc --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/codec.py @@ -0,0 +1,92 @@ +""" +Codec Augmentation via torchaudio + +This library provides codec augmentation techniques in torchaudio for enhanced +audio data processing. + +For detailed guidance and usage examples, refer to the tutorial at: +https://pytorch.org/audio/stable/tutorials/audio_data_augmentation_tutorial.html + +Note: This code is compatible with FFmpeg as the torchaudio backend. +When using FFmpeg2, the maximum number of samples for processing is limited to 16. + +Authors + * Mirco Ravanelli 2023 +""" + +import random + +import torch +import torchaudio + + +class CodecAugment(torch.nn.Module): + """ + Apply random audio codecs to input waveforms using torchaudio. + + This class provides an interface for applying codec augmentation techniques to audio data. + + Arguments + --------- + sample_rate: int + The sample rate of the input waveform. + + Example + ------- + >>> waveform = torch.rand(4, 16000) + >>> if torchaudio.list_audio_backends()[0] == "ffmpeg": + ... augmenter = CodecAugment(16000) + ... output_waveform = augmenter(waveform) + """ + + def __init__(self, sample_rate=16000): + super().__init__() + self.sample_rate = sample_rate + self.available_format_encoders = [ + ("wav", "pcm_mulaw"), + ("mp3", None), + ("g722", None), + ] + + def apply_codec(self, waveform, format=None, encoder=None): + """ + Apply the selected audio codec. + + Arguments + ---------- + waveform: torch.Tensor + Input waveform of shape `[batch, time]`. + format: str + The audio format to use (e.g., "wav", "mp3"). Default is None. + encoder: str + The encoder to use for the format (e.g., "opus", "vorbis"). Default is None. + + Returns + --------- + torch.Tensor: + Coded version of the input waveform of shape `[batch, time]`. + """ + audio_effector = torchaudio.io.AudioEffector( + format=format, encoder=encoder + ) + waveform_aug = audio_effector.apply( + waveform.transpose(0, 1).to("cpu"), self.sample_rate + ) + return waveform_aug.transpose(0, 1).to(waveform.device) + + def forward(self, waveform): + """ + Apply a random audio codec from the available list. + + Arguments + --------- + waveform: torch.Tensor + Input waveform of shape `[batch, time]`. + + Returns + ------- + torch.Tensor + Coded version of the input waveform of shape `[batch, time]`. + """ + format, encoder = random.choice(self.available_format_encoders) + return self.apply_codec(waveform, format=format, encoder=encoder) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/freq_domain.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/freq_domain.py new file mode 100644 index 000000000..4a2acb64f --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/freq_domain.py @@ -0,0 +1,399 @@ +"""Frequency-Domain Sequential Data Augmentation Classes + +This module comprises classes tailored for augmenting sequential data in the +frequency domain, such as spectrograms and mel spectrograms. +Its primary purpose is to enhance the resilience of neural models during the training process. + +Authors: +- Peter Plantinga (2020) +- Mirco Ravanelli (2023) +""" + +import random + +import torch + + +class SpectrogramDrop(torch.nn.Module): + """This class drops slices of the input spectrogram. + + Using `SpectrogramDrop` as an augmentation strategy helps a models learn to rely + on all parts of the signal, since it can't expect a given part to be + present. + + Reference: + https://arxiv.org/abs/1904.08779 + + Arguments + --------- + drop_length_low : int + The low end of lengths for which to drop the + spectrogram, in samples. + drop_length_high : int + The high end of lengths for which to drop the + signal, in samples. + drop_count_low : int + The low end of number of times that the signal + can be dropped. + drop_count_high : int + The high end of number of times that the signal + can be dropped. + replace: str + - 'zeros': Masked values are replaced with zeros. + - 'mean': Masked values are replaced with the mean value of the spectrogram. + - 'rand': Masked values are replaced with random numbers ranging between + the maximum and minimum values of the spectrogram. + - 'cutcat': Masked values are replaced with chunks from other signals in the batch. + - 'swap': Masked values are replaced with other chunks from the same sentence. + - 'random_selection': A random selection among the approaches above. + dim : int + Corresponding dimension to mask. If dim=1, we apply time masking. + If dim=2, we apply frequency masking. + + Example + ------- + >>> # time-masking + >>> drop = SpectrogramDrop(dim=1) + >>> spectrogram = torch.rand(4, 150, 40) + >>> print(spectrogram.shape) + torch.Size([4, 150, 40]) + >>> out = drop(spectrogram) + >>> print(out.shape) + torch.Size([4, 150, 40]) + >>> # frequency-masking + >>> drop = SpectrogramDrop(dim=2) + >>> spectrogram = torch.rand(4, 150, 40) + >>> print(spectrogram.shape) + torch.Size([4, 150, 40]) + >>> out = drop(spectrogram) + >>> print(out.shape) + torch.Size([4, 150, 40]) + """ + + def __init__( + self, + drop_length_low=5, + drop_length_high=15, + drop_count_low=1, + drop_count_high=3, + replace="zeros", + dim=1, + ): + super().__init__() + self.drop_length_low = drop_length_low + self.drop_length_high = drop_length_high + self.drop_count_low = drop_count_low + self.drop_count_high = drop_count_high + self.replace = replace + self.dim = dim + + # Validate low < high + if drop_length_low > drop_length_high: + raise ValueError("Low limit must not be more than high limit") + if drop_count_low > drop_count_high: + raise ValueError("Low limit must not be more than high limit") + + self.replace_opts = [ + "zeros", + "mean", + "rand", + "cutcat", + "swap", + "random_selection", + ] + if self.replace not in self.replace_opts: + raise ValueError( + f"Invalid 'replace' option. Select one of {', '.join(self.replace_opts)}" + ) + + def forward(self, spectrogram): + """ + Apply the DropChunk augmentation to the input spectrogram. + + This method randomly drops chunks of the input spectrogram to augment the data. + + Arguments + --------- + spectrogram : torch.Tensor + Input spectrogram of shape `[batch, time, fea]`. + + Returns + ------- + torch.Tensor + Augmented spectrogram of shape `[batch, time, fea]`. + """ + + # Manage 4D tensors + if spectrogram.dim() == 4: + spectrogram = spectrogram.view( + -1, spectrogram.shape[2], spectrogram.shape[3] + ) + + # Get the batch size + batch_size, time_duration, fea_size = spectrogram.shape + + # Managing masking dimensions + if self.dim == 1: + D = time_duration + else: + D = fea_size + + # Randomly select the number of chunks to drop (same for all samples in the batch) + n_masks = torch.randint( + low=self.drop_count_low, + high=self.drop_count_high + 1, + size=(1,), + device=spectrogram.device, + ) + + # If the number of chunks to drop is 0, return the spectrogram unchanged + if n_masks == 0: + return spectrogram + + # Randomly sample the lengths of the chunks to drop + mask_len = torch.randint( + low=self.drop_length_low, + high=self.drop_length_high, + size=(batch_size, n_masks), + device=spectrogram.device, + ).unsqueeze(2) + + # Randomly sample the positions of the chunks to drop + mask_pos = torch.randint( + 0, + max(1, D, -mask_len.max()), + (batch_size, n_masks), + device=spectrogram.device, + ).unsqueeze(2) + + # Compute the mask for the selected chunk positions + arange = torch.arange(D, device=spectrogram.device).view(1, 1, -1) + mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len)) + mask = mask.any(dim=1) + mask = mask.unsqueeze(2) if self.dim == 1 else mask.unsqueeze(1) + + # Determine the value to replace the masked chunks (zero or mean of the spectrogram) + if self.replace == "random_selection": + self.replace = random.choice(self.replace_opts[:-1]) + + if self.replace == "zeros": + spectrogram = spectrogram.masked_fill_(mask, 0.0) + elif self.replace == "mean": + mean = spectrogram.mean().detach() + spectrogram = spectrogram.masked_fill_(mask, mean) + elif self.replace == "rand": + max_spectrogram = spectrogram.max().detach() + min_spectrogram = spectrogram.min().detach() + rand_spectrogram = torch.rand_like(spectrogram) + rand_spectrogram = ( + rand_spectrogram * (max_spectrogram - min_spectrogram) + + min_spectrogram + ) + mask = mask.float() + spectrogram = (1 - mask) * spectrogram + mask * rand_spectrogram + elif self.replace == "cutcat": + rolled_spectrogram = torch.roll(spectrogram, shifts=1, dims=0) + mask = mask.float() + spectrogram = (1 - mask) * spectrogram + mask * rolled_spectrogram + elif self.replace == "swap": + shift = torch.randint( + low=1, + high=spectrogram.shape[1], + size=(1,), + device=spectrogram.device, + ) + rolled_spectrogram = torch.roll( + spectrogram, shifts=shift.item(), dims=1 + ) + mask = mask.float() + spectrogram = (1 - mask) * spectrogram + mask * rolled_spectrogram + + return spectrogram.view(*spectrogram.shape) + + +class Warping(torch.nn.Module): + """ + Apply time or frequency warping to a spectrogram. + + If `dim=1`, time warping is applied; if `dim=2`, frequency warping is applied. + This implementation selects a center and a window length to perform warping. + It ensures that the temporal dimension remains unchanged by upsampling or + downsampling the affected regions accordingly. + + Reference: + https://arxiv.org/abs/1904.08779 + + Arguments + --------- + warp_window : int, optional + The width of the warping window. Default is 5. + warp_mode : str, optional + The interpolation mode for time warping. Default is "bicubic." + dim : int, optional + Dimension along which to apply warping (1 for time, 2 for frequency). + Default is 1. + + Example + ------- + >>> # Time-warping + >>> warp = Warping() + >>> spectrogram = torch.rand(4, 150, 40) + >>> print(spectrogram.shape) + torch.Size([4, 150, 40]) + >>> out = warp(spectrogram) + >>> print(out.shape) + torch.Size([4, 150, 40]) + >>> # Frequency-warping + >>> warp = Warping(dim=2) + >>> spectrogram = torch.rand(4, 150, 40) + >>> print(spectrogram.shape) + torch.Size([4, 150, 40]) + >>> out = warp(spectrogram) + >>> print(out.shape) + torch.Size([4, 150, 40]) + """ + + def __init__(self, warp_window=5, warp_mode="bicubic", dim=1): + super().__init__() + self.warp_window = warp_window + self.warp_mode = warp_mode + self.dim = dim + + def forward(self, spectrogram): + """ + Apply warping to the input spectrogram. + + Arguments + --------- + spectrogram : torch.Tensor + Input spectrogram with shape `[batch, time, fea]`. + + Returns + ------- + torch.Tensor + Augmented spectrogram with shape `[batch, time, fea]`. + """ + + # Set warping dimension + if self.dim == 2: + spectrogram = spectrogram.transpose(1, 2) + + original_size = spectrogram.shape + window = self.warp_window + + # 2d interpolation requires 4D or higher dimension tensors + # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq) + if spectrogram.dim() == 3: + spectrogram = spectrogram.unsqueeze(1) + + len_original = spectrogram.shape[2] + if len_original - window <= window: + return spectrogram.view(*original_size) + + # Compute center and corresponding window + c = torch.randint(window, len_original - window, (1,))[0] + w = torch.randint(c - window, c + window, (1,))[0] + 1 + + # Update the left part of the spectrogram + left = torch.nn.functional.interpolate( + spectrogram[:, :, :c], + (w, spectrogram.shape[3]), + mode=self.warp_mode, + align_corners=True, + ) + + # Update the right part of the spectrogram. + # When the left part is expanded, the right part is compressed by the + # same factor, and vice versa. + right = torch.nn.functional.interpolate( + spectrogram[:, :, c:], + (len_original - w, spectrogram.shape[3]), + mode=self.warp_mode, + align_corners=True, + ) + + # Injecting the warped left and right parts. + spectrogram[:, :, :w] = left + spectrogram[:, :, w:] = right + spectrogram = spectrogram.view(*original_size) + + # Transpose if freq warping is applied. + if self.dim == 2: + spectrogram = spectrogram.transpose(1, 2) + + return spectrogram + + +class RandomShift(torch.nn.Module): + """Shifts the input tensor by a random amount, allowing for either a time + or frequency (or channel) shift depending on the specified axis. + It is crucial to calibrate the minimum and maximum shifts according to the + requirements of your specific task. + We recommend using small shifts to preserve information integrity. + Using large shifts may result in the loss of significant data and could + potentially lead to misalignments with corresponding labels. + + Arguments + --------- + min_shift : int + The minimum channel shift. + max_shift : int + The maximum channel shift. + dim: int + The dimension to shift. + + Example + ------- + >>> # time shift + >>> signal = torch.zeros(4, 100, 80) + >>> signal[0, 50, :] = 1 + >>> rand_shift = RandomShift(dim=1, min_shift=-10, max_shift=10) + >>> lengths = torch.tensor([0.2, 0.8, 0.9, 1.0]) + >>> output_signal, lengths = rand_shift(signal, lengths) + + >>> # frequency shift + >>> signal = torch.zeros(4, 100, 80) + >>> signal[0, :, 40] = 1 + >>> rand_shift = RandomShift(dim=2, min_shift=-10, max_shift=10) + >>> lengths = torch.tensor([0.2, 0.8, 0.9, 1.0]) + >>> output_signal, lengths = rand_shift(signal, lengths) + """ + + def __init__(self, min_shift=0, max_shift=0, dim=1): + super().__init__() + self.min_shift = min_shift + self.max_shift = max_shift + self.dim = dim + + # Check arguments + if self.max_shift < self.min_shift: + raise ValueError("max_shift must be >= min_shift") + + def forward(self, waveforms, lengths): + """ + Arguments + --------- + waveforms : tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + lengths : tensor + Shape should be a single dimension, `[batch]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + # Pick a frequency to drop + N_shifts = torch.randint( + low=self.min_shift, + high=self.max_shift + 1, + size=(1,), + device=waveforms.device, + ) + waveforms = torch.roll(waveforms, shifts=N_shifts.item(), dims=self.dim) + + # Update lengths in the case of temporal shift. + if self.dim == 1: + lengths = lengths + N_shifts / waveforms.shape[self.dim] + lengths = torch.clamp(lengths, min=0.0, max=1.0) + + return waveforms, lengths diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/preparation.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/preparation.py new file mode 100644 index 000000000..3795cadee --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/preparation.py @@ -0,0 +1,219 @@ +"""Library for Downloading and Preparing Datasets for Data Augmentation, +This library provides functions for downloading datasets from the web and +preparing the necessary CSV data manifest files for use by data augmenters. + +Authors: +* Mirco Ravanelli 2023 + +""" + +import os +import pathlib + +from speechbrain.dataio import audio_io +from speechbrain.utils.data_utils import download_file, get_all_files +from speechbrain.utils.distributed import main_process_only +from speechbrain.utils.logger import get_logger + +# Logger init +logger = get_logger(__name__) + + +@main_process_only +def prepare_dataset_from_URL(URL, dest_folder, ext, csv_file, max_length=None): + """Downloads a dataset containing recordings (e.g., noise sequences) + from the provided URL and prepares the necessary CSV files for use by the noise augmenter. + + Arguments + --------- + URL : str + The URL of the dataset to download. + dest_folder : str + The local folder where the noisy dataset will be downloaded. + ext : str + File extensions to search for within the downloaded dataset. + csv_file : str + The path to store the prepared noise CSV file. + max_length : float + The maximum length in seconds. + Recordings longer than this will be automatically cut into pieces. + """ + + # Download and unpack if necessary + data_file = os.path.join(dest_folder, "data.zip") + + if not os.path.isdir(dest_folder): + download_file(URL, data_file, unpack=True) + else: + download_file(URL, data_file) + + # Prepare noise csv if necessary + if not os.path.isfile(csv_file): + filelist = get_all_files(dest_folder, match_and=["." + ext]) + prepare_csv(filelist, csv_file, max_length) + + +@main_process_only +def prepare_csv(filelist, csv_file, max_length=None): + """Iterate a set of wavs and write the corresponding csv file. + + Arguments + --------- + filelist : str + A list containing the paths of files of interest. + csv_file : str + The path to store the prepared noise CSV file. + max_length : float + The maximum length in seconds. + Recordings longer than this will be automatically cut into pieces. + """ + try: + write_csv(filelist, csv_file, max_length) + except Exception as e: + # Handle the exception or log the error message + logger.error("Exception:", exc_info=(e)) + + # Delete the file if something fails + if os.path.exists(csv_file): + os.remove(csv_file) + + +@main_process_only +def write_csv(filelist, csv_file, max_length=None): + """ + Iterate through a list of audio files and write the corresponding CSV file. + + Arguments + --------- + filelist : list of str + A list containing the paths of audio files of interest. + csv_file : str + The path where to store the prepared noise CSV file. + max_length : float (optional) + The maximum recording length in seconds. + Recordings longer than this will be automatically cut into pieces. + """ + with open(csv_file, "w", encoding="utf-8") as w: + w.write("ID,duration,wav,wav_format,wav_opts\n") + for i, filename in enumerate(filelist): + _write_csv_row(w, filename, i, max_length) + + +def _write_csv_row(w, filename, index, max_length): + """ + Write a single row to the CSV file based on the audio file information. + + Arguments + --------- + w : file + The open CSV file for writing. + filename : str + The path to the audio file. + index : int + The index of the audio file in the list. + max_length : float (optional) + The maximum recording length in seconds. + """ + signal, rate = audio_io.load(filename) + signal = _ensure_single_channel(signal, filename, rate) + + ID, ext = os.path.basename(filename).split(".") + duration = signal.shape[1] / rate + + if max_length is not None and duration > max_length: + _handle_long_waveform( + w, filename, ID, ext, signal, rate, duration, max_length, index + ) + else: + _write_short_waveform_csv(w, ID, ext, duration, filename, index) + + +def _ensure_single_channel(signal, filename, rate): + """ + Ensure that the audio signal has only one channel. + + Arguments + --------- + signal : torch.Tensor + The audio signal. + filename : str + The path to the audio file. + rate : int + The sampling frequency of the signal. + + Returns + ------- + signal : Torch.Tensor + The audio signal with a single channel. + """ + if signal.shape[0] > 1: + signal = signal[0].unsqueeze(0) + audio_io.save(filename, signal, rate) + return signal + + +def _handle_long_waveform( + w, filename, ID, ext, signal, rate, duration, max_length, index +): + """ + Handle long audio waveforms by cutting them into pieces and writing to the CSV. + + Arguments + --------- + w : file + The open CSV file for writing. + filename : str + The path to the audio file. + ID : str + The unique identifier for the audio. + ext : str + The audio file extension. + signal : torch.Tensor + The audio signal. + rate : int + The audio sample rate. + duration : float + The duration of the audio in seconds. + max_length : float + The maximum recording length in seconds. + index : int + The index of the audio file in the list. + """ + os.remove(filename) + filename = pathlib.Path(filename) + for j in range(int(duration / max_length)): + start = int(max_length * j * rate) + stop = int(min(max_length * (j + 1), duration) * rate) + new_filename = filename.with_stem(filename.stem + f"_{j}") + + audio_io.save(new_filename, signal[:, start:stop], rate) + csv_row = ( + f"{ID}_{index}_{j}", + str((stop - start) / rate), + str(new_filename), + ext, + "\n", + ) + w.write(",".join(csv_row)) + + +def _write_short_waveform_csv(w, ID, ext, duration, filename, index): + """ + Write a CSV row for a short audio waveform. + + Arguments + --------- + w : file + The open CSV file for writing. + ID : str + The unique identifier for the audio. + ext : str + The audio file extension. + duration : float + The duration of the audio in seconds. + filename : str + The path to the audio file. + index : int + The index of the audio file in the list. + """ + w.write(",".join((f"{ID}_{index}", str(duration), filename, ext, "\n"))) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/time_domain.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/time_domain.py new file mode 100644 index 000000000..9db2d05f7 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/augment/time_domain.py @@ -0,0 +1,1540 @@ +"""Time-Domain Sequential Data Augmentation Classes + +This module contains classes designed for augmenting sequential data in the time domain. +It is particularly useful for enhancing the robustness of neural models during training. +The available data distortions include adding noise, applying reverberation, adjusting playback speed, and more. +All classes are implemented as `torch.nn.Module`, enabling end-to-end differentiability and gradient backpropagation. + +Authors: +- Peter Plantinga (2020) +- Mirco Ravanelli (2023) +- Gianfranco Dumoulin Bertucci (2025) +""" + +# Importing libraries +import random + +import torch +import torch.nn.functional as F +import torchaudio + +from speechbrain.dataio.dataloader import make_dataloader +from speechbrain.dataio.legacy import ExtendedCSVDataset +from speechbrain.processing.signal_processing import ( + compute_amplitude, + convolve1d, + dB_to_amplitude, + notch_filter, + reverberate, +) + + +class AddNoise(torch.nn.Module): + """This class additively combines a noise signal to the input signal. + + Arguments + --------- + csv_file : str + The name of a csv file containing the location of the + noise audio files. If none is provided, white noise will be used. + csv_keys : list, None, optional + Default: None . One data entry for the noise data should be specified. + If None, the csv file is expected to have only one data entry. + sorting : str + The order to iterate the csv file, from one of the + following options: random, original, ascending, and descending. + num_workers : int + Number of workers in the DataLoader (See PyTorch DataLoader docs). + snr_low : int + The low end of the mixing ratios, in decibels. + snr_high : int + The high end of the mixing ratios, in decibels. + pad_noise : bool + If True, copy noise signals that are shorter than + their corresponding clean signals so as to cover the whole clean + signal. Otherwise, leave the noise un-padded. + start_index : int + The index in the noise waveforms to start from. By default, chooses + a random index in [0, len(noise) - len(waveforms)]. + normalize : bool + If True, output noisy signals that exceed [-1,1] will be + normalized to [-1,1]. + noise_funct: funct object + function to use to draw a noisy sample. It is enabled if the csv files + containing the noisy sequences are not provided. By default, + torch.randn_like is used (to sample white noise). In general, it must + be a function that takes in input the original waveform and returns + a tensor with the corresponding noise to add (e.g., see pink_noise_like). + replacements : dict + A set of string replacements to carry out in the + csv file. Each time a key is found in the text, it will be replaced + with the corresponding value. + noise_sample_rate : int + The sample rate of the noise audio signals, so noise can be resampled + to the clean sample rate if necessary. + clean_sample_rate : int + The sample rate of the clean audio signals, so noise can be resampled + to the clean sample rate if necessary. + + Example + ------- + >>> import pytest + >>> from speechbrain.dataio.dataio import read_audio + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> clean = signal.unsqueeze(0) # [batch, time, channels] + >>> noisifier = AddNoise( + ... "tests/samples/annotation/noise.csv", + ... replacements={"noise_folder": "tests/samples/noise"}, + ... ) + >>> noisy = noisifier(clean, torch.ones(1)) + """ + + def __init__( + self, + csv_file=None, + csv_keys=None, + sorting="random", + num_workers=0, + snr_low=0, + snr_high=0, + pad_noise=False, + start_index=None, + normalize=False, + noise_funct=torch.randn_like, + replacements={}, + noise_sample_rate=16000, + clean_sample_rate=16000, + ): + super().__init__() + + self.csv_file = csv_file + self.csv_keys = csv_keys + self.sorting = sorting + self.num_workers = num_workers + self.snr_low = snr_low + self.snr_high = snr_high + self.pad_noise = pad_noise + self.start_index = start_index + self.normalize = normalize + self.replacements = replacements + self.noise_funct = noise_funct + self.noise_sample_rate = noise_sample_rate + self.clean_sample_rate = clean_sample_rate + + def forward(self, waveforms, lengths): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + lengths : torch.Tensor + Shape should be a single dimension, `[batch]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]`. + """ + + # Copy clean waveform to initialize noisy waveform + noisy_waveform = waveforms.clone() + lengths = (lengths * waveforms.shape[1]).unsqueeze(1) + + # Compute the average amplitude of the clean waveforms + clean_amplitude = compute_amplitude(waveforms, lengths, amp_type="rms") + + # Pick an SNR and use it to compute the mixture amplitude factors + SNR = torch.rand(len(waveforms), 1, device=waveforms.device) + SNR = SNR * (self.snr_high - self.snr_low) + self.snr_low + noise_amplitude_factor = 1 / (dB_to_amplitude(SNR) + 1) + + # Support for multichannel waveforms + if len(noisy_waveform.shape) == 3: + noise_amplitude_factor = noise_amplitude_factor.unsqueeze(1) + + # Scale clean signal appropriately + new_noise_amplitude = noise_amplitude_factor * clean_amplitude + noisy_waveform *= 1 - noise_amplitude_factor + + # Loop through clean samples and create mixture + if self.csv_file is None: + noise_waveform = self.noise_funct(waveforms) + if noise_waveform.shape[0] == 1: + noise_waveform = torch.cat( + [noise_waveform] * waveforms.shape[0], dim=0 + ) + + noise_length = lengths + else: + tensor_length = waveforms.shape[1] + noise_waveform, noise_length = self._load_noise( + lengths, tensor_length + ) + + # Rescale and add + noise_amplitude = compute_amplitude( + noise_waveform, noise_length, amp_type="rms" + ) + noise_waveform *= new_noise_amplitude / (noise_amplitude + 1e-14) + + noisy_waveform += noise_waveform + # Normalizing to prevent clipping + if self.normalize: + abs_max, _ = torch.max( + torch.abs(noisy_waveform), dim=1, keepdim=True + ) + noisy_waveform = noisy_waveform / abs_max.clamp(min=1.0) + + return noisy_waveform + + def _load_noise(self, lengths, max_length): + """Load a batch of noises""" + lengths = lengths.long().squeeze(1) + batch_size = len(lengths) + + # Load a noise batch + if not hasattr(self, "data_loader"): + if self.noise_sample_rate != self.clean_sample_rate: + self.resampler = Resample( + self.noise_sample_rate, self.clean_sample_rate + ) + + # Set parameters based on input + self.device = lengths.device + + # Create a data loader for the noise wavforms + if self.csv_file is not None: + dataset = ExtendedCSVDataset( + csvpath=self.csv_file, + output_keys=self.csv_keys, + sorting=( + self.sorting if self.sorting != "random" else "original" + ), + replacements=self.replacements, + ) + self.data_loader = make_dataloader( + dataset, + batch_size=batch_size, + num_workers=self.num_workers, + shuffle=(self.sorting == "random"), + ) + self.noise_data = iter(self.data_loader) + + # Load noise to correct device + noise_batch, noise_len = self._load_noise_batch_of_size(batch_size) + noise_batch = noise_batch.to(lengths.device) + noise_len = noise_len.to(lengths.device) + + # Resample noise if necessary + if hasattr(self, "resampler"): + noise_batch = self.resampler(noise_batch) + + # Convert relative length to an index + noise_len = (noise_len * noise_batch.shape[1]).long() + + # Ensure shortest wav can cover speech signal + # WARNING: THIS COULD BE SLOW IF THERE ARE VERY SHORT NOISES + if self.pad_noise: + while torch.any(noise_len < lengths): + min_len = torch.min(noise_len) + prepend = noise_batch[:, :min_len] + noise_batch = torch.cat((prepend, noise_batch), axis=1) + noise_len += min_len + + # Ensure noise batch is long enough + elif noise_batch.size(1) < max_length: + padding = (0, max_length - noise_batch.size(1)) + noise_batch = torch.nn.functional.pad(noise_batch, padding) + + # Select a random starting location in the waveform + start_index = self.start_index + if self.start_index is None: + start_index = 0 + max_chop = (noise_len - lengths).min().clamp(min=1) + start_index = torch.randint( + high=max_chop, size=(1,), device=lengths.device + ) + + # Truncate noise_batch to max_length + noise_batch = noise_batch[:, start_index : start_index + max_length] + noise_len = (noise_len - start_index).clamp(max=max_length).unsqueeze(1) + return noise_batch, noise_len + + def _load_noise_batch_of_size(self, batch_size): + """Concatenate noise batches, then chop to correct size""" + + noise_batch, noise_lens = self._load_noise_batch() + + # Expand + while len(noise_batch) < batch_size: + added_noise, added_lens = self._load_noise_batch() + noise_batch, noise_lens = AddNoise._concat_batch( + noise_batch, noise_lens, added_noise, added_lens + ) + + # Contract + if len(noise_batch) > batch_size: + noise_batch = noise_batch[:batch_size] + noise_lens = noise_lens[:batch_size] + + return noise_batch, noise_lens + + @staticmethod + def _concat_batch(noise_batch, noise_lens, added_noise, added_lens): + """Concatenate two noise batches of potentially different lengths""" + + # pad shorter batch to correct length + noise_tensor_len = noise_batch.shape[1] + added_tensor_len = added_noise.shape[1] + pad = (0, abs(noise_tensor_len - added_tensor_len)) + if noise_tensor_len > added_tensor_len: + added_noise = torch.nn.functional.pad(added_noise, pad) + added_lens = added_lens * added_tensor_len / noise_tensor_len + else: + noise_batch = torch.nn.functional.pad(noise_batch, pad) + noise_lens = noise_lens * noise_tensor_len / added_tensor_len + + noise_batch = torch.cat((noise_batch, added_noise)) + noise_lens = torch.cat((noise_lens, added_lens)) + + return noise_batch, noise_lens + + def _load_noise_batch(self): + """Load a batch of noises, restarting iteration if necessary.""" + + try: + # Don't necessarily know the key + noises, lens = next(self.noise_data).at_position(0) + except StopIteration: + self.noise_data = iter(self.data_loader) + noises, lens = next(self.noise_data).at_position(0) + return noises, lens + + +class AddReverb(torch.nn.Module): + """This class convolves an audio signal with an impulse response. + + Arguments + --------- + csv_file : str + The name of a csv file containing the location of the + impulse response files. + sorting : str + The order to iterate the csv file, from one of + the following options: random, original, ascending, and descending. + num_workers : int + Number of workers in the DataLoader (See PyTorch DataLoader docs). + rir_scale_factor: float + It compresses or dilates the given impulse response. + If 0 < scale_factor < 1, the impulse response is compressed + (less reverb), while if scale_factor > 1 it is dilated + (more reverb). + replacements : dict + A set of string replacements to carry out in the + csv file. Each time a key is found in the text, it will be replaced + with the corresponding value. + reverb_sample_rate : int + The sample rate of the corruption signals (rirs), so that they + can be resampled to clean sample rate if necessary. + clean_sample_rate : int + The sample rate of the clean signals, so that the corruption + signals can be resampled to the clean sample rate before convolution. + + Example + ------- + >>> import pytest + >>> from speechbrain.dataio.dataio import read_audio + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> clean = signal.unsqueeze(0) # [batch, time, channels] + >>> reverb = AddReverb( + ... "tests/samples/annotation/RIRs.csv", + ... replacements={"rir_folder": "tests/samples/RIRs"}, + ... ) + >>> reverbed = reverb(clean) + """ + + def __init__( + self, + csv_file, + sorting="random", + num_workers=0, + rir_scale_factor=1.0, + replacements={}, + reverb_sample_rate=16000, + clean_sample_rate=16000, + ): + super().__init__() + self.csv_file = csv_file + self.sorting = sorting + self.num_workers = num_workers + self.replacements = replacements + self.reverb_sample_rate = reverb_sample_rate + self.clean_sample_rate = clean_sample_rate + self.rir_scale_factor = rir_scale_factor + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]`. + """ + + if self.reverb_sample_rate != self.clean_sample_rate: + self.resampler = Resample( + self.reverb_sample_rate, self.clean_sample_rate + ) + + # Add channels dimension if necessary + channel_added = False + if len(waveforms.shape) == 2: + waveforms = waveforms.unsqueeze(-1) + channel_added = True + + # Load and prepare RIR + rir_waveform = self._load_rir(waveforms) + + # Resample to correct rate + if hasattr(self, "resampler"): + rir_waveform = self.resampler(rir_waveform) + + # Compress or dilate RIR + if self.rir_scale_factor != 1: + rir_waveform = F.interpolate( + rir_waveform.transpose(1, -1), + scale_factor=self.rir_scale_factor, + mode="linear", + align_corners=False, + ) + rir_waveform = rir_waveform.transpose(1, -1) + + rev_waveform = reverberate(waveforms, rir_waveform, rescale_amp="avg") + + # Remove channels dimension if added + if channel_added: + return rev_waveform.squeeze(-1) + + return rev_waveform + + def _load_rir(self, waveforms): + # Create a data loader for the RIR waveforms + if not hasattr(self, "data_loader"): + dataset = ExtendedCSVDataset( + csvpath=self.csv_file, + sorting=( + self.sorting if self.sorting != "random" else "original" + ), + replacements=self.replacements, + ) + self.data_loader = make_dataloader( + dataset, + shuffle=(self.sorting == "random"), + num_workers=self.num_workers, + ) + self.rir_data = iter(self.data_loader) + + try: + rir_waveform, length = next(self.rir_data).at_position(0) + except StopIteration: + self.rir_data = iter(self.data_loader) + rir_waveform, length = next(self.rir_data).at_position(0) + + # Make sure RIR has correct channels + if len(rir_waveform.shape) == 2: + rir_waveform = rir_waveform.unsqueeze(-1) + + # Make sure RIR has correct type and device + rir_waveform = rir_waveform.type(waveforms.dtype) + return rir_waveform.to(waveforms.device) + + +class SpeedPerturb(torch.nn.Module): + """Slightly speed up or slow down an audio signal. + + Resample the audio signal at a rate that is similar to the original rate, + to achieve a slightly slower or slightly faster signal. This technique is + outlined in the paper: "Audio Augmentation for Speech Recognition" + + Arguments + --------- + orig_freq : int + The frequency of the original signal. + speeds : list + The speeds that the signal should be changed to, as a percentage of the + original signal (i.e. `speeds` is divided by 100 to get a ratio). + device : str + The device to use for the resampling. + + Example + ------- + >>> from speechbrain.dataio.dataio import read_audio + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> perturbator = SpeedPerturb(orig_freq=16000, speeds=[90]) + >>> clean = signal.unsqueeze(0) + >>> perturbed = perturbator(clean) + >>> clean.shape + torch.Size([1, 52173]) + >>> perturbed.shape + torch.Size([1, 57971]) + """ + + def __init__(self, orig_freq, speeds=[90, 100, 110], device="cpu"): + super().__init__() + self.orig_freq = orig_freq + self.speeds = speeds + self.device = device + # Initialize index of perturbation + self.samp_index = 0 + + # Initialize resamplers + self.resamplers = [] + for speed in self.speeds: + config = { + "orig_freq": self.orig_freq, + "new_freq": round(self.orig_freq * 100 / speed), + } + self.resamplers.append(Resample(**config)) + + def forward(self, waveform): + """ + Arguments + --------- + waveform : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + torch.Tensor of shape `[batch, time]` or `[batch, time, channels]`. + """ + + # Perform a random perturbation + self.samp_index = torch.randint(0, len(self.speeds), (1,)) + perturbed_waveform = self.resamplers[self.samp_index]( + waveform.to(self.device) + ) + # Move back from host to original device + return perturbed_waveform.to(waveform.device) + + +class Resample(torch.nn.Module): + """This class resamples audio using the + :class:`torchaudio resampler ` based on + sinc interpolation. + + Arguments + --------- + orig_freq : int + the sampling frequency of the input signal. + new_freq : int + the new sampling frequency after this operation is performed. + *args + additional arguments forwarded to the + :class:`torchaudio.transforms.Resample` constructor + **kwargs + additional keyword arguments forwarded to the + :class:`torchaudio.transforms.Resample` constructor + + Example + ------- + >>> from speechbrain.dataio.dataio import read_audio + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> signal = signal.unsqueeze(0) # [batch, time, channels] + >>> resampler = Resample(orig_freq=16000, new_freq=8000) + >>> resampled = resampler(signal) + >>> signal.shape + torch.Size([1, 52173]) + >>> resampled.shape + torch.Size([1, 26087]) + """ + + def __init__(self, orig_freq=16000, new_freq=16000, *args, **kwargs): + super().__init__() + + self.orig_freq = orig_freq + self.new_freq = new_freq + + self.resampler = torchaudio.transforms.Resample( + orig_freq=orig_freq, new_freq=new_freq, *args, **kwargs + ) + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]`. + """ + + # Don't do anything if the frequencies are the same + if self.orig_freq == self.new_freq: + return waveforms + + unsqueezed = False + if len(waveforms.shape) == 2: + waveforms = waveforms.unsqueeze(1) + unsqueezed = True + elif len(waveforms.shape) == 3: + waveforms = waveforms.transpose(1, 2) + else: + raise ValueError("Input must be 2 or 3 dimensions") + + # If necessary, migrate the resampler to the current device, for + # backwards compat with scripts that do not call `resampler.to()` + # themselves. + # Please do not reuse the sample resampler for tensors that live on + # different devices, though. + self.resampler.to(waveforms.device) # in-place + + # Do resampling + resampled_waveform = self.resampler(waveforms) + + if unsqueezed: + resampled_waveform = resampled_waveform.squeeze(1) + else: + resampled_waveform = resampled_waveform.transpose(1, 2) + + return resampled_waveform + + +class DropFreq(torch.nn.Module): + """This class drops a random frequency from the signal. + + The purpose of this class is to teach models to learn to rely on all parts + of the signal, not just a few frequency bands. + + Arguments + --------- + drop_freq_low : float + The low end of frequencies that can be dropped, + as a fraction of the sampling rate / 2. + drop_freq_high : float + The high end of frequencies that can be + dropped, as a fraction of the sampling rate / 2. + drop_freq_count_low : int + The low end of number of frequencies that could be dropped. + drop_freq_count_high : int + The high end of number of frequencies that could be dropped. + drop_freq_width : float + The width of the frequency band to drop, as + a fraction of the sampling_rate / 2. + epsilon : float + A small positive value to prevent issues such as filtering 0 Hz, + division by zero, or other numerical instabilities. This value sets + the absolute minimum for normalized frequencies used in the filter. + The default value is 1e-12. + + Example + ------- + >>> from speechbrain.dataio.dataio import read_audio + >>> dropper = DropFreq() + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> dropped_signal = dropper(signal.unsqueeze(0)) + """ + + def __init__( + self, + drop_freq_low=1e-14, + drop_freq_high=1, + drop_freq_count_low=1, + drop_freq_count_high=3, + drop_freq_width=0.05, + epsilon=1e-12, + ): + super().__init__() + self.drop_freq_low = drop_freq_low + self.drop_freq_high = drop_freq_high + self.drop_freq_count_low = drop_freq_count_low + self.drop_freq_count_high = drop_freq_count_high + self.drop_freq_width = drop_freq_width + self.epsilon = epsilon + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]`. + """ + + # Don't drop (return early) 1-`drop_prob` portion of the batches + dropped_waveform = waveforms.clone() + + # Add channels dimension + if len(waveforms.shape) == 2: + dropped_waveform = dropped_waveform.unsqueeze(-1) + + # Pick number of frequencies to drop + drop_count = torch.randint( + low=self.drop_freq_count_low, + high=self.drop_freq_count_high + 1, + size=(1,), + ) + + # Pick a frequency to drop + drop_range = self.drop_freq_high - self.drop_freq_low + drop_frequency = ( + torch.rand(drop_count) * drop_range + self.drop_freq_low + ).clamp(min=self.epsilon) + # Filter parameters + filter_length = 101 + pad = filter_length // 2 + + # Start with delta function + drop_filter = torch.zeros(1, filter_length, 1, device=waveforms.device) + drop_filter[0, pad, 0] = 1 + + # Subtract each frequency + for frequency in drop_frequency: + notch_kernel = notch_filter( + frequency, filter_length, self.drop_freq_width + ).to(waveforms.device) + drop_filter = convolve1d(drop_filter, notch_kernel, pad) + + # Manage multiple channels + if len(waveforms.shape) == 3: + dropped_waveform = dropped_waveform.reshape( + dropped_waveform.shape[0] * dropped_waveform.shape[2], + dropped_waveform.shape[1], + 1, + ) + + # Apply filter + dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad) + + if len(waveforms.shape) == 3: + dropped_waveform = dropped_waveform.reshape( + waveforms.shape[0], waveforms.shape[1], waveforms.shape[2] + ) + + # Remove channels dimension if added + return dropped_waveform.squeeze(-1) + + +class DropChunk(torch.nn.Module): + """This class drops portions of the input signal. + + Using `DropChunk` as an augmentation strategy helps a models learn to rely + on all parts of the signal, since it can't expect a given part to be + present. + + Arguments + --------- + drop_length_low : int + The low end of lengths for which to set the + signal to zero, in samples. + drop_length_high : int + The high end of lengths for which to set the + signal to zero, in samples. + drop_count_low : int + The low end of number of times that the signal + can be dropped to zero. + drop_count_high : int + The high end of number of times that the signal + can be dropped to zero. + drop_start : int + The first index for which dropping will be allowed. + drop_end : int + The last index for which dropping will be allowed. + noise_factor : float + The factor relative to average amplitude of an utterance + to use for scaling the white noise inserted. 1 keeps + the average amplitude the same, while 0 inserts all 0's. + + Example + ------- + >>> from speechbrain.dataio.dataio import read_audio + >>> dropper = DropChunk(drop_start=100, drop_end=200, noise_factor=0.0) + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> signal = signal.unsqueeze(0) # [batch, time, channels] + >>> length = torch.ones(1) + >>> dropped_signal = dropper(signal, length) + >>> float(dropped_signal[:, 150]) + 0.0 + """ + + def __init__( + self, + drop_length_low=100, + drop_length_high=1000, + drop_count_low=1, + drop_count_high=3, + drop_start=0, + drop_end=None, + noise_factor=0.0, + ): + super().__init__() + self.drop_length_low = drop_length_low + self.drop_length_high = drop_length_high + self.drop_count_low = drop_count_low + self.drop_count_high = drop_count_high + self.drop_start = drop_start + self.drop_end = drop_end + self.noise_factor = noise_factor + + # Validate low < high + if drop_length_low > drop_length_high: + raise ValueError("Low limit must not be more than high limit") + if drop_count_low > drop_count_high: + raise ValueError("Low limit must not be more than high limit") + + # Make sure the length doesn't exceed end - start + if drop_end is not None and drop_end >= 0: + if drop_start > drop_end: + raise ValueError("Low limit must not be more than high limit") + + drop_range = drop_end - drop_start + self.drop_length_low = min(drop_length_low, drop_range) + self.drop_length_high = min(drop_length_high, drop_range) + + def forward(self, waveforms, lengths): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + lengths : torch.Tensor + Shape should be a single dimension, `[batch]`. + + Returns + ------- + Tensor of shape `[batch, time]` or + `[batch, time, channels]` + """ + + # Reading input list + lengths = (lengths * waveforms.size(1)).long() + batch_size = waveforms.size(0) + dropped_waveform = waveforms.clone() + + # Store original amplitude for computing white noise amplitude + clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1)) + + # Pick a number of times to drop + drop_times = torch.randint( + low=self.drop_count_low, + high=self.drop_count_high + 1, + size=(batch_size,), + ) + + # Iterate batch to set mask + for i in range(batch_size): + if drop_times[i] == 0: + continue + + # Pick lengths + length = torch.randint( + low=self.drop_length_low, + high=self.drop_length_high + 1, + size=(drop_times[i],), + ) + + # Compute range of starting locations + start_min = self.drop_start + if start_min < 0: + start_min += lengths[i] + start_max = self.drop_end + if start_max is None: + start_max = lengths[i] + if start_max < 0: + start_max += lengths[i] + start_max = max(0, start_max - length.max()) + + # Pick starting locations + start = torch.randint( + low=start_min, high=start_max + 1, size=(drop_times[i],) + ) + + end = start + length + + # Update waveform + if not self.noise_factor: + for j in range(drop_times[i]): + dropped_waveform[i, start[j] : end[j]] = 0.0 + else: + # Uniform distribution of -2 to +2 * avg amplitude should + # preserve the average for normalization + noise_max = 2 * clean_amplitude[i] * self.noise_factor + for j in range(drop_times[i]): + # zero-center the noise distribution + noise_vec = torch.rand(length[j], device=waveforms.device) + noise_vec = 2 * noise_max * noise_vec - noise_max + dropped_waveform[i, start[j] : end[j]] = noise_vec + + return dropped_waveform + + +class FastDropChunk(torch.nn.Module): + """This class drops portions of the input signal. The difference with + DropChunk is that in this case we pre-compute the dropping masks in the + first time the forward function is called. For all the other calls, we only + shuffle and apply them. This makes the code faster and more suitable for + data augmentation of large batches. + + It can be used only for fixed-length sequences. + + Arguments + --------- + drop_length_low : int + The low end of lengths for which to set the + signal to zero, in samples. + drop_length_high : int + The high end of lengths for which to set the + signal to zero, in samples. + drop_count_low : int + The low end of number of times that the signal + can be dropped to zero. + drop_count_high : int + The high end of number of times that the signal + can be dropped to zero. + drop_start : int + The first index for which dropping will be allowed. + drop_end : int + The last index for which dropping will be allowed. + n_masks : int + The number of precomputed masks. + + Example + ------- + >>> from speechbrain.dataio.dataio import read_audio + >>> dropper = FastDropChunk(drop_start=100, drop_end=200) + >>> signal = torch.rand(10, 250, 22) + >>> dropped_signal = dropper(signal) + """ + + def __init__( + self, + drop_length_low=100, + drop_length_high=1000, + drop_count_low=1, + drop_count_high=10, + drop_start=0, + drop_end=None, + n_masks=1000, + ): + super().__init__() + self.drop_length_low = drop_length_low + self.drop_length_high = drop_length_high + self.drop_count_low = drop_count_low + self.drop_count_high = drop_count_high + self.drop_start = drop_start + self.drop_end = drop_end + self.n_masks = n_masks + self.first = True + + # Validate low < high + if drop_length_low > drop_length_high: + raise ValueError("Low limit must not be more than high limit") + if drop_count_low > drop_count_high: + raise ValueError("Low limit must not be more than high limit") + + # Make sure the length doesn't exceed end - start + if drop_end is not None and drop_end >= 0: + if drop_start > drop_end: + raise ValueError("Low limit must not be more than high limit") + drop_range = drop_end - drop_start + self.drop_length_low = min(drop_length_low, drop_range) + self.drop_length_high = min(drop_length_high, drop_range) + + def initialize_masks(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + `. + Returns + ------- + dropped_masks : torch.Tensor + Tensor of size `[n_masks, time]` with the dropped chunks. Dropped + regions are assigned to 0. + """ + + if self.n_masks < waveforms.shape[0]: + raise ValueError("n_mask cannot be smaller than the batch size") + + # Initializing the drop mask + dropped_masks = torch.ones( + [self.n_masks, self.sig_len], device=waveforms.device + ) + + # Pick a number of times to drop + drop_times = torch.randint( + low=self.drop_count_low, + high=self.drop_count_high + 1, + size=(self.n_masks,), + device=waveforms.device, + ) + + # Iterate batch to set mask + for i in range(self.n_masks): + if drop_times[i] == 0: + continue + + # Pick lengths + length = torch.randint( + low=self.drop_length_low, + high=self.drop_length_high + 1, + size=(drop_times[i],), + device=waveforms.device, + ) + + # Compute range of starting locations + start_min = self.drop_start + if start_min < 0: + start_min += self.sig_len + start_max = self.drop_end + if start_max is None: + start_max = self.sig_len + if start_max < 0: + start_max += self.sig_len + start_max = max(0, start_max - length.max()) + + # Pick starting locations + start = torch.randint( + low=start_min, + high=start_max + 1, + size=(drop_times[i],), + device=waveforms.device, + ) + + end = start + length + + # Update waveform + for j in range(drop_times[i]): + dropped_masks[i, start[j] : end[j]] = 0.0 + + return dropped_masks + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + + dropped_waveforms = waveforms.clone() + + # Initialize the masks + if self.first: + self.sig_len = waveforms.shape[1] + self.dropped_masks = self.initialize_masks(waveforms) + self.first = False + + # Random Permutation + rand_perm = torch.randperm(self.dropped_masks.shape[0]) + self.dropped_masks = self.dropped_masks[rand_perm, :] + + # Random shift in time + rand_shifts = torch.randint(low=0, high=self.sig_len, size=(1,)) + self.dropped_masks = torch.roll( + self.dropped_masks, shifts=rand_shifts.item(), dims=1 + ) + + if len(waveforms.shape) == 3: + dropped_waveforms = dropped_waveforms * self.dropped_masks[ + 0 : waveforms.shape[0] + ].unsqueeze(2) + else: + dropped_waveforms = ( + dropped_waveforms * self.dropped_masks[0 : waveforms.shape[0]] + ) + + return dropped_waveforms + + +class DoClip(torch.nn.Module): + """This function mimics audio clipping by clamping the input tensor. + First, it normalizes the waveforms from -1 to -1. Then, clipping is applied. + Finally, the original amplitude is restored. + + Arguments + --------- + clip_low : float + The low end of amplitudes for which to clip the signal. + clip_high : float + The high end of amplitudes for which to clip the signal. + + Example + ------- + >>> from speechbrain.dataio.dataio import read_audio + >>> clipper = DoClip(clip_low=0.01, clip_high=0.01) + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> clipped_signal = clipper(signal.unsqueeze(0)) + """ + + def __init__(self, clip_low=0.5, clip_high=0.5): + super().__init__() + self.clip_low = clip_low + self.clip_high = clip_high + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + + # Normalize the signal + abs_max, _ = torch.max(torch.abs(waveforms), dim=1, keepdim=True) + waveforms = waveforms / abs_max + + # Randomly select clip value + clipping_range = self.clip_high - self.clip_low + clip_value = ( + torch.rand(1, device=waveforms.device)[0] * clipping_range + + self.clip_low + ) + + # Apply clipping + clipped_waveform = waveforms.clamp(-clip_value, clip_value) + + # Restore original amplitude + clipped_waveform = clipped_waveform * abs_max / clip_value + + return clipped_waveform + + +class RandAmp(torch.nn.Module): + """This function multiples the signal by a random amplitude. First, the + signal is normalized to have amplitude between -1 and 1. Then it is + multiplied with a random number. + + Arguments + --------- + amp_low : float + The minimum amplitude multiplication factor. + amp_high : float + The maximum amplitude multiplication factor. + + Example + ------- + >>> from speechbrain.dataio.dataio import read_audio + >>> rand_amp = RandAmp(amp_low=0.25, amp_high=1.75) + >>> signal = read_audio("tests/samples/single-mic/example1.wav") + >>> output_signal = rand_amp(signal.unsqueeze(0)) + """ + + def __init__(self, amp_low=0.5, amp_high=1.5): + super().__init__() + self.amp_low = amp_low + self.amp_high = amp_high + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + + # Normalize the signal + abs_max, _ = torch.max(torch.abs(waveforms), dim=1, keepdim=True) + waveforms = waveforms / abs_max + + # Pick a frequency to drop + rand_range = self.amp_high - self.amp_low + amp = ( + torch.rand(waveforms.shape[0], device=waveforms.device) * rand_range + + self.amp_low + ) + amp = amp.unsqueeze(1) + if len(waveforms.shape) == 3: + amp = amp.unsqueeze(2) + waveforms = waveforms * amp + + return waveforms + + +class ChannelDrop(torch.nn.Module): + """This function drops random channels in the multi-channel input waveform. + + Arguments + --------- + drop_rate : float + The channel dropout factor + + Example + ------- + >>> signal = torch.rand(4, 256, 8) + >>> ch_drop = ChannelDrop(drop_rate=0.5) + >>> output_signal = ch_drop(signal) + """ + + def __init__(self, drop_rate=0.1): + super().__init__() + self.drop_rate = drop_rate + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + + # Pick a channel to drop + x = torch.rand(waveforms.shape[-1], device=waveforms.device) + channel_mask = x.ge(self.drop_rate) + waveforms = waveforms * channel_mask.unsqueeze(0).unsqueeze(1) + return waveforms + + +class ChannelSwap(torch.nn.Module): + """This function randomly swaps N channels. + + Arguments + --------- + min_swap : int + The minimum number of channels to swap. + max_swap : int + The maximum number of channels to swap. + + Example + ------- + >>> signal = torch.rand(4, 256, 8) + >>> ch_swap = ChannelSwap() + >>> output_signal = ch_swap(signal) + """ + + def __init__(self, min_swap=0, max_swap=0): + super().__init__() + self.min_swap = min_swap + self.max_swap = max_swap + + # Check arguments + if self.min_swap < 0: + raise ValueError("min_swap must be >= 0.") + if self.max_swap < 0: + raise ValueError("max_swap must be >= 0.") + if self.max_swap < self.min_swap: + raise ValueError("max_swap must be >= min_swap") + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + + # Pick a frequency to drop + rand_perm1 = torch.randperm(waveforms.shape[-1]) + rand_perm2 = torch.randperm(waveforms.shape[-1]) + N_swaps = torch.randint( + low=self.min_swap, high=self.max_swap + 1, size=(1,) + ) + + if N_swaps < waveforms.shape[-1]: + for i in range(N_swaps): + store_channel = waveforms[:, :, rand_perm2[i]] + waveforms[:, :, rand_perm2[i]] = waveforms[:, :, rand_perm1[i]] + waveforms[:, :, rand_perm1[i]] = store_channel + else: + # Full swap + waveforms = waveforms[:, :, rand_perm1] + + return waveforms + + +class CutCat(torch.nn.Module): + """This function combines segments (with equal length in time) of the time series contained in the batch. + Proposed for EEG signals in https://doi.org/10.1016/j.neunet.2021.05.032. + + Arguments + --------- + min_num_segments : int + The number of segments to combine. + max_num_segments : int + The maximum number of segments to combine. Default is 10. + + Example + ------- + >>> signal = torch.ones((4, 256, 22)) * torch.arange(4).reshape( + ... ( + ... 4, + ... 1, + ... 1, + ... ) + ... ) + >>> cutcat = CutCat() + >>> output_signal = cutcat(signal) + """ + + def __init__(self, min_num_segments=2, max_num_segments=10): + super().__init__() + self.min_num_segments = min_num_segments + self.max_num_segments = max_num_segments + # Check arguments + if self.max_num_segments < self.min_num_segments: + raise ValueError("max_num_segments must be >= min_num_segments") + + def forward(self, waveforms): + """ + Arguments + --------- + waveforms : torch.Tensor + Shape should be `[batch, time]` or `[batch, time, channels]`. + + Returns + ------- + Tensor of shape `[batch, time]` or `[batch, time, channels]` + """ + if ( + waveforms.shape[0] > 1 + ): # only if there are at least 2 examples in batch + # rolling waveforms to point to segments of other examples in batch + waveforms_rolled = torch.roll(waveforms, shifts=1, dims=0) + # picking number of segments to use + num_segments = torch.randint( + low=self.min_num_segments, + high=self.max_num_segments + 1, + size=(1,), + ) + # index of cuts (both starts and stops) + idx_cut = torch.linspace( + 0, waveforms.shape[1], num_segments.item() + 1, dtype=torch.int + ) + for i in range(idx_cut.shape[0] - 1): + # half of segments from other examples in batch + if i % 2 == 1: + start = idx_cut[i] + stop = idx_cut[i + 1] + waveforms[:, start:stop, ...] = waveforms_rolled[ + :, start:stop, ... + ] + + return waveforms + + +def pink_noise_like(waveforms, alpha_low=1.0, alpha_high=1.0, sample_rate=50): + """Creates a sequence of pink noise (also known as 1/f). The pink noise + is obtained by multiplying the spectrum of a white noise sequence by a + factor (1/f^alpha). + The alpha factor controls the decrease factor in the frequency domain + (alpha=0 adds white noise, alpha>>0 adds low frequency noise). It is + randomly sampled between alpha_low and alpha_high. With negative alpha this + function generates blue noise. + + Arguments + --------- + waveforms : torch.Tensor + The original waveform. It is just used to infer the shape. + alpha_low : float + The minimum value for the alpha spectral smoothing factor. + alpha_high : float + The maximum value for the alpha spectral smoothing factor. + sample_rate : float + The sample rate of the original signal. + + Returns + ------- + pink_noise : torch.Tensor + Pink noise in the shape of the input tensor. + + Example + ------- + >>> waveforms = torch.randn(4, 257, 10) + >>> noise = pink_noise_like(waveforms) + >>> noise.shape + torch.Size([4, 257, 10]) + """ + # Sampling white noise (flat spectrum) + white_noise = torch.randn_like(waveforms) + + # Computing the fft of the input white noise + white_noise_fft = torch.fft.fft(white_noise, dim=1) + + # Sampling the spectral smoothing factor + rand_range = alpha_high - alpha_low + alpha = ( + torch.rand(waveforms.shape[0], device=waveforms.device) * rand_range + + alpha_low + ) + + # preparing the spectral mask (1/f^alpha) + f = torch.linspace( + 0, + sample_rate / 2, + int(white_noise.shape[1] / 2), + device=waveforms.device, + ) + spectral_mask = 1 / torch.pow(f.unsqueeze(0), alpha.unsqueeze(1)) + + # Avoid inf due to 1/0 division at f=0 + spectral_mask[:, 0] = spectral_mask[:, 1] + + # Mask for the upper part of the spectrum (f > sample_rate/2) + spectral_mask_up = torch.flip(spectral_mask, dims=(1,)) + + # Managing odd/even sequences + if white_noise.shape[1] % 2: + mid_element = spectral_mask[ + :, int(white_noise.shape[1] / 2) - 1 + ].unsqueeze(1) + spectral_mask = torch.cat( + [spectral_mask, mid_element, spectral_mask_up], dim=1 + ) + else: + spectral_mask = torch.cat([spectral_mask, spectral_mask_up], dim=1) + + # Managing multi-channel inputs + if len(white_noise.shape) == 3: + spectral_mask = spectral_mask.unsqueeze(2) + + # Spectral masking + pink_noise_fft = white_noise_fft * spectral_mask + + # Return to the time-domain + pink_noise = torch.fft.ifft(pink_noise_fft, dim=1).real + return pink_noise + + +class DropBitResolution(torch.nn.Module): + """ + This class transforms a float32 tensor into a lower resolution one + (e.g., int16, int8, float16) and then converts it back to a float32. + This process loses information and can be used for data augmentation. + + Arguments: + --------- + target_dtype: str + One of "int16", "int8", "float16". If "random", the bit resolution + is randomly selected among the options listed above. + + Example: + >>> dropper = DropBitResolution() + >>> signal = torch.rand(4, 16000) + >>> signal_dropped = dropper(signal) + """ + + def __init__(self, target_dtype="random"): + super().__init__() + + self.target_dtype = target_dtype + self.bit_depths = { + "int16": (16, torch.int16), + "int8": (8, torch.int8), + "float16": (16, torch.float16), + } + + if ( + self.target_dtype != "random" + and self.target_dtype not in self.bit_depths + ): + raise ValueError( + f"target_dtype must be one of {list(self.bit_depths.keys())}" + ) + + def forward(self, float32_tensor): + """ + Arguments: + --------- + float32_tensor: torch.Tensor + Float32 tensor with shape `[batch, time]` or `[batch, time, channels]`. + + Returns: + --------- + torch.Tensor + Tensor of shape `[batch, time]` or `[batch, time, channels]` (Float32) + """ + + if self.target_dtype == "random": + random_key = random.choice(list(self.bit_depths.keys())) + bit, target_dtype = self.bit_depths[random_key] + else: + bit, target_dtype = self.bit_depths[self.target_dtype] + + # Define a scale factor to map the float32 range to the target bit depth + if target_dtype != torch.float16: + scale_factor = (2 ** (bit - 1) - 1) / float32_tensor.abs().max() + quantized_tensor = (float32_tensor * scale_factor).to(target_dtype) + else: + quantized_tensor = float32_tensor.half() + scale_factor = 1 + + # To dequantize and recover the original float32 values + dequantized_tensor = quantized_tensor.to(torch.float32) / scale_factor + return dequantized_tensor + + +class SignFlip(torch.nn.Module): + """Flip the sign of a signal. + + This module negates all the values in a tensor with a given probability. + If the sign is not flipped, the original signal is returned + unchanged. This technique is outlined in the paper: + "CADDA: Class-wise Automatic Differentiable Data Augmentation for EEG Signals" + https://arxiv.org/pdf/2106.13695 + + Arguments + --------- + flip_prob : float + The probability with which to flip the sign of the signal. Default is 0.5. + + Example + ------- + >>> import torch + >>> x = torch.tensor([1, 2, 3, 4, 5]) + >>> flip = SignFlip(flip_prob=1) # 100% chance to flip sign + >>> flip(x) + tensor([-1, -2, -3, -4, -5]) + """ + + def __init__(self, flip_prob=0.5): + super().__init__() + self.flip_prob = flip_prob + + def forward(self, waveform): + """ + Arguments + --------- + waveform : torch.Tensor + Input tensor representaing waveform, shape does not matter. + + Returns + ------- + torch.Tensor + The output tensor with same shape as the input, where the + sign of all values in the tensor has been flipped with + probability `flip_prob`. + + """ + + # Flip sign with `flip_prob` probability. + if torch.rand(1).item() < self.flip_prob: + return -waveform + + return waveform diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/core.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/core.py new file mode 100644 index 000000000..55286c718 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/core.py @@ -0,0 +1,1489 @@ +"""Core SpeechBrain code for running experiments. + +Authors + * Peter Plantinga 2020, 2023 + * Abdel Heba 2020 + * Mirco Ravanelli 2020 + * Aku Rouhe 2021 + * Andreas Nautsch 2022 + * Sylvain de Langen 2023 + * Adel Moumen 2023, 2024 +""" + +import inspect +import logging +import os +import pathlib +import shutil +import sys +import tempfile +import time +import warnings +from contextlib import contextmanager +from datetime import date +from enum import Enum, auto +from types import SimpleNamespace + +import torch +import yaml +from hyperpyyaml import resolve_references +from packaging import version +from torch.nn import ( + DataParallel as DP, + SyncBatchNorm, +) +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.data import DataLoader, DistributedSampler, IterableDataset +from tqdm import tqdm + +import speechbrain as sb +from speechbrain.dataio.dataloader import LoopedLoader, SaveableDataLoader +from speechbrain.dataio.sampler import ( + DistributedSamplerWrapper, + ReproducibleRandomSampler, +) +from speechbrain.utils.autocast import AMPConfig, TorchAutocast +from speechbrain.utils.distributed import is_distributed_initialized +from speechbrain.utils.logger import get_logger +from speechbrain.utils.optimizers import rm_vector_weight_decay +from speechbrain.utils.profiling import prepare_profiler +from speechbrain.utils.run_opts import RunOptions + +sb.utils.quirks.apply_quirks() + +logger = get_logger(__name__) +DEFAULT_LOG_CONFIG = os.path.dirname(os.path.abspath(__file__)) +DEFAULT_LOG_CONFIG = os.path.join(DEFAULT_LOG_CONFIG, "log-config.yaml") +INTRA_EPOCH_CKPT_FLAG = "brain_intra_epoch_ckpt" +PYTHON_VERSION_MAJOR = 3 +PYTHON_VERSION_MINOR = 8 + + +def create_experiment_directory( + experiment_directory, + hyperparams_to_save=None, + overrides={}, + log_config=DEFAULT_LOG_CONFIG, + save_env_desc=True, +): + """Create the output folder and relevant experimental files. + + Arguments + --------- + experiment_directory : str + The place where the experiment directory should be created. + hyperparams_to_save : str + A filename of a yaml file representing the parameters for this + experiment. If passed, references are resolved, and the result is + written to a file in the experiment directory called "hyperparams.yaml". + overrides : dict + A mapping of replacements made in the yaml file, to save in yaml. + log_config : str + A yaml filename containing configuration options for the logger. + save_env_desc : bool + If True, an environment state description is saved to the experiment + directory, in a file called env.log in the experiment directory. + """ + try: + # all writing command must be done with the main_process + if sb.utils.distributed.if_main_process(): + if not os.path.isdir(experiment_directory): + os.makedirs(experiment_directory) + + # Write the parameters file + if hyperparams_to_save is not None: + hyperparams_filename = os.path.join( + experiment_directory, "hyperparams.yaml" + ) + with open(hyperparams_to_save, encoding="utf-8") as f: + resolved_yaml = resolve_references(f, overrides) + with open(hyperparams_filename, "w", encoding="utf-8") as w: + print("# Generated %s from:" % date.today(), file=w) + print("# %s" % os.path.abspath(hyperparams_to_save), file=w) + print("# yamllint disable", file=w) + shutil.copyfileobj(resolved_yaml, w) + + # Copy executing file to output directory + module = inspect.getmodule(inspect.currentframe().f_back) + if module is not None: + callingfile = os.path.realpath(module.__file__) + shutil.copy(callingfile, experiment_directory) + + # Log exceptions to output automatically + log_file = os.path.join(experiment_directory, "log.txt") + logger_overrides = { + "handlers": {"file_handler": {"filename": log_file}} + } + sb.utils.logger.setup_logging(log_config, logger_overrides) + sys.excepthook = _logging_excepthook + + # Log quirks again so that it makes it to the log file. + # Quirks are applied way earlier, before logging is properly setup, + # so this gives a chance to the user to see them, lowering surprise. + sb.utils.quirks.log_applied_quirks() + + # Log beginning of experiment! + logger.info("Beginning experiment!") + logger.info(f"Experiment folder: {experiment_directory}") + + # Save system description: + if save_env_desc: + description_str = sb.utils.logger.get_environment_description() + with open( + os.path.join(experiment_directory, "env.log"), + "w", + encoding="utf-8", + ) as fo: + fo.write(description_str) + finally: + # wait for main_process if ddp is used + sb.utils.distributed.ddp_barrier() + + +def _logging_excepthook(exc_type, exc_value, exc_traceback): + """Interrupt exception raising to log the error.""" + logger.error("Exception:", exc_info=(exc_type, exc_value, exc_traceback)) + + +class Stage(Enum): + """Simple enum to track stage of experiments.""" + + TRAIN = auto() + VALID = auto() + TEST = auto() + + +@sb.utils.checkpoints.register_checkpoint_hooks +class Brain: + """Brain class abstracts away the details of data loops. + + The primary purpose of the `Brain` class is the implementation of + the ``fit()`` method, which iterates epochs and datasets for the + purpose of "fitting" a set of modules to a set of data. + + In order to use the ``fit()`` method, one should sub-class the ``Brain`` + class and override any methods for which the default behavior does not + match the use case. For a simple use case (e.g., training a single model + with a single dataset) the only methods that need to be overridden are: + + * ``compute_forward()`` + * ``compute_objectives()`` + + The example below illustrates how overriding these two methods is done. + + For more complicated use cases, such as multiple modules that need to + be updated, the following methods can be overridden: + + * ``fit_batch()`` + * ``evaluate_batch()`` + + Arguments + --------- + modules : dict[str, torch.nn.Module] + These modules are passed to the optimizer by default if they have + trainable parameters, and will have ``train()``/``eval()`` called on them. + opt_class : Optional[Type[torch.optim]] + A torch optimizer constructor that takes only the list of + parameters (e.g. a lambda or partial function definition). By default, + this will be passed all modules in ``modules`` at the + beginning of the ``fit()`` method. This behavior can be changed + by overriding the ``configure_optimizers()`` method. + hparams : Optional[dict] + Each key:value pair should consist of a string key and a hyperparameter + that is used within the overridden methods. These will + be accessible via an ``hparams`` attribute, using "dot" notation: + e.g., self.hparams.model(x). + run_opts : Optional[Union[RunOptions, dict]] + A set of options to change the runtime environment, see ``RunOptions`` for a list. + Typically in a script this comes from ``speechbrain.parse_args``, an alias + for ``RunOptions.from_command_line_args``. If an option is not defined here + (keep in mind that `parse_args` will inject some options by default), + then the option is also searched for in hparams (by key). + checkpointer : Optional[speechbrain.utils.checkpoints.Checkpointer] + By default, this will be used to load checkpoints, and will have the + optimizer added to continue training if interrupted. + + Example + ------- + >>> from torch.optim import SGD + >>> class SimpleBrain(Brain): + ... def compute_forward(self, batch, stage): + ... return self.modules.model(batch[0] * self.hparams.scalar) + ... + ... def compute_objectives(self, predictions, batch, stage): + ... return torch.nn.functional.l1_loss(predictions, batch[0]) + >>> model = torch.nn.Linear(in_features=10, out_features=10) + >>> brain = SimpleBrain( + ... modules={"model": model}, + ... opt_class=lambda x: SGD(x, lr=0.1), + ... hparams={"scalar": 5}, + ... run_opts={"device": "cpu"}, + ... ) + >>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],)) + """ + + def __init__( # noqa: C901 + self, + modules=None, + opt_class=None, + hparams=None, + run_opts=None, + checkpointer=None, + ): + self.optimizers_dict = None + self.opt_class = opt_class + self.checkpointer = checkpointer + if isinstance(run_opts, dict): + run_opts = RunOptions.from_dictionary(run_opts) + + # Check which options have been overridden. Order of priority + # is lowest: default < hparams < run_opts: highest + run_opt_defaults = RunOptions() + for arg, default in run_opt_defaults.as_dict().items(): + if run_opts is not None and arg in run_opts.overridden_args: + if hparams is not None and arg in hparams: + logger.info( + f"{arg} which is specified in hparams was overridden " + + f"by command line input to: {run_opts[arg]}" + ) + setattr(self, arg, run_opts[arg]) + + # If any arg from run_opt_defaults exist in hparams and + # not in "run_opts" which is likely from command line + elif hparams is not None and arg in hparams: + logger.info(f"Run option {arg} from hparams is used") + setattr(self, arg, hparams[arg]) + else: + setattr(self, arg, default) + + # Check Python version + if not ( + sys.version_info.major == PYTHON_VERSION_MAJOR + and sys.version_info.minor >= PYTHON_VERSION_MINOR + ): + logger.warning( + "Detected Python " + + str(sys.version_info.major) + + "." + + str(sys.version_info.minor) + + ". We suggest using SpeechBrain with Python >=" + + str(PYTHON_VERSION_MAJOR) + + "." + + str(PYTHON_VERSION_MINOR) + ) + + # Assume `torchrun` was used if `RANK` and `LOCAL_RANK` are set + self.distributed_launch = ( + os.environ.get("RANK") is not None + and os.environ.get("LOCAL_RANK") is not None + ) + + if self.data_parallel_backend and self.distributed_launch: + raise ValueError( + "To use data_parallel backend, start your script with:\n\t" + "python experiment.py hyperparams.yaml " + "--data_parallel_backend=True\n" + "To use DDP backend, start your script with:\n\t" + "torchrun [args] experiment.py hyperparams.yaml" + ) + + if self.ckpt_interval_minutes > 0 and self.ckpt_interval_steps > 0: + sys.exit( + "The options `ckpt_interval_minutes` and `ckpt_interval_steps` " + "are mutually exclusive. " + "Please keep only one active per experiment run." + ) + + # If device was not specified, then make best guess + if self.device is None: + self.device = sb.utils.distributed.infer_device() + + # Set device type based on device string + if self.device == "cpu": + self.device_type = "cpu" + elif "cuda" in self.device: + self.device_type = "cuda" + + # Set cuda device based on device string + try: + _, device_index = self.device.split(":") + torch.cuda.set_device(int(device_index)) + except ValueError: + torch.cuda.set_device(0) + + # Checking that DataParallel use the right number of GPU + if self.data_parallel_backend and torch.cuda.device_count() == 0: + raise ValueError("You must have at least 1 GPU to use DataParallel") + + # Put modules on the right device, accessible with dot notation + self.modules = torch.nn.ModuleDict(modules).to(self.device) + + # The next line ensures that both tensors marked as parameters and standard tensors, + # such as those used in InputNormalization, are placed on the right device. + for module in self.modules: + if hasattr(self.modules[module], "to"): + self.modules[module] = self.modules[module].to(self.device) + + # Make hyperparams available with dot notation too + if hparams is not None: + self.hparams = SimpleNamespace(**hparams) + + # Checkpointer should point at a temporary directory in debug mode + if ( + self.debug + and not self.debug_persistently + and self.checkpointer is not None + and hasattr(self.checkpointer, "checkpoints_dir") + ): + tempdir = tempfile.TemporaryDirectory() + logger.info( + "Since debug mode is active, switching checkpointer " + f"output to temporary directory: {tempdir.name}" + ) + self.checkpointer.checkpoints_dir = pathlib.Path(tempdir.name) + + # Keep reference to tempdir as long as checkpointer exists + self.checkpointer.tempdir = tempdir + + # Sampler should be handled by `make_dataloader` + # or if you provide a DataLoader directly, you can set + # this.train_sampler = your_sampler + # to have your_sampler.set_epoch() called on each epoch. + self.train_sampler = None + + if self.auto_mix_prec: + logger.warning( + "The option `--auto_mix_prec` is deprecated and will be removed in the future. " + "Please use `--precision=fp16` instead." + ) + self.precision = "fp16" + + if self.bfloat16_mix_prec: + logger.warning( + "The option `--bfloat16_mix_prec` is deprecated and will be removed in the future. " + "Please use `--precision=bf16` instead." + ) + self.precision = "bf16" + + if self.device_type == "cpu" and ( + self.precision == "fp16" or self.eval_precision == "fp16" + ): + raise ValueError( + "The option `--precision` or `--eval_precision` is set to fp16. " + "This option is not yet supported on CPU. " + "Please use `--precision=bf16` or `--eval_precision=bf16` instead " + "to enable mixed precision on CPU." + ) + + gradscaler_enabled = ( + self.precision == "fp16" and self.device_type == "cuda" + ) + if self.skip_nonfinite_grads and gradscaler_enabled: + logger.warning( + "The option `skip_nonfinite_grads` will be ignored " + "because GradScaler is enabled and will automatically " + "skip nonfinite gradients." + ) + + logger.info(f"Gradscaler enabled: `{gradscaler_enabled}`") + logger.info(f"Using training precision: `--precision={self.precision}`") + logger.info( + f"Using evaluation precision: `--eval_precision={self.eval_precision}`" + ) + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled) + else: + self.scaler = torch.GradScaler( + self.device, enabled=gradscaler_enabled + ) + + train_dtype = AMPConfig.from_name(self.precision).dtype + self.training_ctx = TorchAutocast( + device_type=self.device_type, dtype=train_dtype + ) + eval_dtype = AMPConfig.from_name(self.eval_precision).dtype + self.evaluation_ctx = TorchAutocast( + device_type=self.device_type, dtype=eval_dtype + ) + if gradscaler_enabled and self.checkpointer is not None: + self.checkpointer.add_recoverable( + "scaler", self.scaler, optional_load=True + ) + + # List parameter count for the user + self.print_trainable_parameters() + + if self.distributed_launch: + self.rank = int(os.environ["RANK"]) + if not is_distributed_initialized(): + if self.rank > 0: + raise ValueError( + " ================ WARNING ===============" + "Please add sb.ddp_init_group() into your exp.py" + "To use DDP backend, start your script with:\n\t" + "torchrun [args] experiment.py hyperparams.yaml" + ) + else: + logger.warning( + "To use DDP, please add " + "sb.utils.distributed.ddp_init_group() into your exp.py" + ) + logger.info( + "Only the main process is alive, " + "all other subprocess were killed." + ) + + # Prepare iterating variables + self.avg_train_loss = 0.0 + self.step = 0 + self.optimizer_step = 0 + + # Add this class to the checkpointer for intra-epoch checkpoints + if self.checkpointer is not None: + self.checkpointer.add_recoverable("brain", self) + + # Force default color for tqdm progressbar + if not self.tqdm_colored_bar: + self.tqdm_barcolor = dict.fromkeys(self.tqdm_barcolor, "") + + # Profiler setup + self.profiler = None + if self.profile_training: + logger.info("Pytorch profiler has been activated.") + self.tot_prof_steps = (self.profile_steps + self.profile_warmup) - 1 + self.profiler = prepare_profiler( + self.profile_warmup, + self.profile_steps, + self.hparams.output_folder, + ) + + self.raw_modules = ( + self.modules.module + if hasattr(self.modules, "module") + else self.modules + ) + + def print_trainable_parameters(self): + """Prints the number of trainable parameters in the model.""" + total_trainable_params = 0 + total_parameters = 0 + for parameter in self.modules.parameters(): + total_parameters += parameter.numel() + if parameter.requires_grad: + total_trainable_params += parameter.numel() + class_name = self.__class__.__name__ + if total_parameters == 0: + logger.warning("The model has no parameters!") + logger.info( + f"{class_name} Model Statistics:\n" + f"* Total Number of Trainable Parameters: {total_trainable_params}\n" + f"* Total Number of Parameters: {total_parameters}\n" + f"* Trainable Parameters represent {0:.2f}% of the total size." + ) + elif total_trainable_params == 0: + logger.warning("The model has no trainable parameters!") + formatted_total_params = sb.utils.logger.format_order_of_magnitude( + total_parameters + ) + logger.info( + f"{class_name} Model Statistics:\n" + f"* Total Number of Trainable Parameters: {total_trainable_params}\n" + f"* Total Number of Parameters: {formatted_total_params}\n" + f"* Trainable Parameters represent {0:.4f}% of the total size." + ) + else: + percentage_trainable = ( + 100 * total_trainable_params / total_parameters + ) + formatted_trainable_params = ( + sb.utils.logger.format_order_of_magnitude( + total_trainable_params + ) + ) + formatted_total_params = sb.utils.logger.format_order_of_magnitude( + total_parameters + ) + logger.info( + f"{class_name} Model Statistics:\n" + f"* Total Number of Trainable Parameters: {formatted_trainable_params}\n" + f"* Total Number of Parameters: {formatted_total_params}\n" + f"* Trainable Parameters represent {percentage_trainable:.4f}% of the total size." + ) + + def compute_forward(self, batch, stage): + """Forward pass, to be overridden by sub-classes. + + Arguments + --------- + batch : torch.Tensor or tensors + An element from the dataloader, including inputs for processing. + stage : Stage + The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST + + Returns + ------- + torch.Tensor or torch.Tensors + The outputs after all processing is complete. + Directly passed to ``compute_objectives()``. + """ + raise NotImplementedError + return + + def compute_objectives(self, predictions, batch, stage): + """Compute loss, to be overridden by sub-classes. + + Arguments + --------- + predictions : torch.Tensor or torch.Tensors + The output tensor or tensors to evaluate. + Comes directly from ``compute_forward()``. + batch : torch.Tensor or tensors + An element from the dataloader, including targets for comparison. + stage : Stage + The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST + + Returns + ------- + loss : torch.Tensor + A tensor with the computed loss. + """ + raise NotImplementedError + return + + def on_stage_start(self, stage, epoch=None): + """Gets called when a stage starts. + + Useful for defining class variables used during the stage. + + Arguments + --------- + stage : Stage + The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST + epoch : int + The current epoch count. + """ + pass + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of a stage. + + Useful for computing stage statistics, saving checkpoints, etc. + + Arguments + --------- + stage : Stage + The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST + stage_loss : float + The average loss over the completed stage. + epoch : int + The current epoch count. + """ + pass + + def make_dataloader( + self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs + ): + """Creates DataLoaders for Datasets. + + This is used by ``fit()`` and ``evaluate()`` if they just receive + Datasets. + + Alternatively, this can be called from outside the Brain subclass. + In that case, the DataLoader should be passed to ``fit()`` in place + of the dataset. + + The Stage.TRAIN DataLoader is handled specially. It has extra args for + shuffle and drop_last. In DDP a DistributedSampler is created (unless + the dataset is an IterableDataset). + + NOTE + ---- + Some important DataLoader arguments are passed via **loader_kwargs, + e.g., batch_size, num_workers, pin_memory. + + NOTE + ---- + By default, ``evaluate()`` specifies ckpt_prefix=None to stop the test + DataLoader being added to the checkpointer. If you need to add a + recoverable after saving checkpoints (e.g., at test time, after + checkpointing the training), and still be able to recover reasonably, + you should probably specify ``allow_partial_load=True``. + + Arguments + --------- + dataset : Dataset + A set of data to use to create data loader. If the Dataset is a + DynamicItemDataset, PaddedBatch is used as the default collate_fn, + unless specified in loader_kwargs. + stage : Stage + The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST + ckpt_prefix : str, None + Prefix to use for SaveableDataLoader Checkpoint name. The Stage + name is added to this to create the full key. Set to None to not + save the DataLoader. + **loader_kwargs : dict + Additional keyword arguments to the DataLoader. + E.g., batch_size, num_workers, pin_memory. + + Returns + ------- + DataLoader for the input dataset + """ + # TRAIN stage is handled specially. + if stage == sb.Stage.TRAIN: + loader_kwargs = self._train_loader_specifics(dataset, loader_kwargs) + # This commented-out code block is useful when one can ensure + # metric reporting is DDP-valid for VALID & EVAL datasets. + # elif self.distributed_launch: + # loader_kwargs = sb.dataio.dataloader.distributed_loader_specifics( + # self.distributed_launch, self.rank, dataset, loader_kwargs + # ) + dataloader = sb.dataio.dataloader.make_dataloader( + dataset, **loader_kwargs + ) + + if ( + self.checkpointer is not None + and ckpt_prefix is not None + and ( + isinstance(dataloader, SaveableDataLoader) + or isinstance(dataloader, LoopedLoader) + ) + ): + ckpt_key = ckpt_prefix + stage.name + self.checkpointer.add_recoverable(ckpt_key, dataloader) + return dataloader + + def _train_loader_specifics(self, dataset, loader_kwargs): + sampler = loader_kwargs.get("sampler", None) + # Shuffling should really only matter for the train stage. Shuffling + # will also lead to more padding in batches if the order was otherwise + # sorted by length. + shuffle = loader_kwargs.get("shuffle", False) + if shuffle and not self.distributed_launch: + if sampler is not None: + raise ValueError( + "Cannot specify both shuffle=True" + "and a sampler in loader_kwargs" + ) + seed = os.environ.get("SB_GLOBAL_SEED", 563375142) + sampler = ReproducibleRandomSampler(dataset, seed=seed) + self.train_sampler = sampler + loader_kwargs["sampler"] = self.train_sampler + # Delete the shuffle flag, since you cannot specify both a sampler and + # shuffling: + del loader_kwargs["shuffle"] + + # Possibly make a DistributedSampler or a wrapper for some other sampler + if self.distributed_launch and not isinstance(dataset, IterableDataset): + # sort or not + if hasattr(self.hparams, "sorting"): + shuffle_ddp = ( + self.hparams.sorting == "random" + ) # False if 'ascending' or 'descending' + else: + shuffle_ddp = True + + drop_last = loader_kwargs.get("drop_last", False) + # num_replicas arg is equal to world_size + # and retrieved automatically within + # DistributedSampler obj. + if sampler is not None: + self.train_sampler = DistributedSamplerWrapper( + sampler, + rank=self.rank, + drop_last=drop_last, + shuffle=shuffle, + ) + + # with DistributedSamplerWrapper, one must disable shuffling for dataloader + loader_kwargs["shuffle"] = False + loader_kwargs["sampler"] = self.train_sampler + elif loader_kwargs.get("batch_sampler") is None: + # no sampler and batch-sampler + self.train_sampler = DistributedSampler( + dataset, + rank=self.rank, + shuffle=shuffle_ddp, + drop_last=drop_last, + ) + + # with DistributedSamplerWrapper, one must disable shuffling for dataloader + loader_kwargs["shuffle"] = False + loader_kwargs["sampler"] = self.train_sampler + else: # batch_sampler was specified + self.train_sampler = DistributedSamplerWrapper( + loader_kwargs.get("batch_sampler", None), + rank=self.rank, + shuffle=shuffle_ddp, + ) + loader_kwargs["batch_sampler"] = self.train_sampler + elif self.distributed_launch and isinstance(dataset, IterableDataset): + logger.warning( + "Cannot automatically solve distributed sampling " + "for IterableDataset." + ) + return loader_kwargs + + def on_fit_start(self): + """Gets called at the beginning of ``fit()``, on multiple processes + if ``distributed_count > 0`` and backend is ddp. + + Default implementation compiles the jit modules, initializes + optimizers, and loads the latest checkpoint to resume training. + """ + # Run this *after* starting all processes since jit/compiled modules + # cannot be pickled. + self._compile() + + # Wrap modules with parallel backend after jit + self._wrap_distributed() + + # Initialize optimizers after parameters are configured + self.init_optimizers() + + # Load latest checkpoint to resume training if interrupted + if self.checkpointer is not None: + self.checkpointer.recover_if_possible() + + def init_optimizers(self): + """Called during ``on_fit_start()``, initialize optimizers + after parameters are fully configured (e.g. DDP, jit). + + The default implementation of this method depends on an optimizer + class being passed at initialization that takes only a list + of parameters (e.g., a lambda or a partial function definition). + This creates a single optimizer that optimizes all trainable params. + + Override this class if there are multiple optimizers. + """ + + all_params = self.modules.parameters() + + if self.opt_class is not None: + if self.remove_vector_weight_decay: + all_params = rm_vector_weight_decay(self.modules) + + self.optimizer = self.opt_class(all_params) + + self.optimizers_dict = {"opt_class": self.optimizer} + + if self.checkpointer is not None: + self.checkpointer.add_recoverable("optimizer", self.optimizer) + else: + logger.info( + "No `opt_class` was provided to this Brain class, " + "skipping optimizer initialization." + ) + + def zero_grad(self, set_to_none=False): + """Sets the gradients of all optimized ``torch.Tensor``s to zero + if ``set_to_none=False`` (default) or to None otherwise. + + Setting gradients to None should save the memory, e.g. + during ``evaluate()`` and thus larger batch might be used. + """ + if self.optimizers_dict is not None: + for opt in self.freeze_optimizers(self.optimizers_dict).values(): + opt.zero_grad(set_to_none=set_to_none) + elif self.opt_class is not None: + self.optimizer.zero_grad(set_to_none=set_to_none) + + def on_evaluate_start(self, max_key=None, min_key=None): + """Gets called at the beginning of ``evaluate()`` + + Default implementation loads the best-performing checkpoint for + evaluation, based on stored metrics. + + Arguments + --------- + max_key : str + Key to use for finding best checkpoint (higher is better). + By default, passed to ``self.checkpointer.recover_if_possible()``. + min_key : str + Key to use for finding best checkpoint (lower is better). + By default, passed to ``self.checkpointer.recover_if_possible()``. + """ + + # Recover best checkpoint for evaluation + if self.checkpointer is not None: + self.checkpointer.recover_if_possible( + max_key=max_key, min_key=min_key + ) + + def fit_batch(self, batch): + """Fit one batch, override to do multiple updates. + + The default implementation depends on a few methods being defined + with a particular behavior: + + * ``compute_forward()`` + * ``compute_objectives()`` + * ``optimizers_step()`` + + Also depends on having optimizers passed at initialization. + + Arguments + --------- + batch : list of torch.Tensors + Batch of data to use for training. Default implementation assumes + this batch has two elements: inputs and targets. + + Returns + ------- + detached loss + """ + should_step = (self.step % self.grad_accumulation_factor) == 0 + self.on_fit_batch_start(batch, should_step) + + with self.no_sync(not should_step): + with self.training_ctx: + outputs = self.compute_forward(batch, sb.Stage.TRAIN) + loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN) + scaled_loss = self.scaler.scale( + loss / self.grad_accumulation_factor + ) + self.check_loss_isfinite(scaled_loss) + scaled_loss.backward() + + if should_step: + self.optimizers_step() + + self.on_fit_batch_end(batch, outputs, loss, should_step) + return loss.detach().cpu() + + def check_loss_isfinite(self, loss): + """Check if the loss is finite. + + If the loss is not finite, log a helpful message and increment the `nonfinite_count`. + If the `nonfinite_count` exceeds the `--nonfinite_patience` threshold, stop the training + and raise an error. + + This check is particularly useful when the loss becomes NaN or inf, while the + parameters and gradients remain finite. It helps prevent getting stuck in an + infinite loop during training. + + Arguments + --------- + loss : tensor + The loss tensor after ``backward()`` has been called but + before the optimizers ``step()``. + """ + if not torch.isfinite(loss): + self.nonfinite_count += 1 + + # Check if patience is exhausted + if self.nonfinite_count > self.nonfinite_patience: + raise ValueError( + "Loss is not finite and patience is exhausted. " + "To debug, wrap `fit()` with " + "autograd's `detect_anomaly()`, e.g.\n\nwith " + "torch.autograd.detect_anomaly():\n\tbrain.fit(...)" + ) + else: + logger.warning("Patience not yet exhausted.") + + def check_gradients(self): + """Checks if the gradients are finite. If not, it will emit a warning and set them to zero.""" + for param in self.modules.parameters(): + if param.requires_grad and param.grad is not None: + if not torch.isfinite(param.grad).all(): + param.grad = None + logger.warning( + f"Gradients {param.name} contain NaN or Inf. Setting to None." + ) + + def freeze_optimizers(self, optimizers): + """By default, this method returns the passed optimizers. + Override this method if you want to freeze some optimizers + during training. To do so, return a of active optimizers. + """ + return optimizers + + def optimizers_step(self): + """Performs a step of gradient descent on the optimizers. This method is called every + ``grad_accumulation_factor`` steps.""" + # 1. get the valid optimizers, i.e., the ones that are not frozen during this step + if self.optimizers_dict is not None: + valid_optimizers = self.freeze_optimizers(self.optimizers_dict) + elif self.opt_class is not None: + # if valid_optimizers is not defined which could happen if a user is using an old + # init_optimizers() method, then we assume that the only valid optimizer is + # self.optimizer (which is the default behavior). + valid_optimizers = {"optimizer": self.optimizer} + else: + # Note: in some cases you might want to only compute gradients statistics and + # you do not need to call the optimizers.step() method. In this case, you can + # simply return from this method and skip the rest of the code. + return + + # 2. unscale the gradients of the valid optimizers + for opt in valid_optimizers.values(): + self.scaler.unscale_(opt) + + # 3. clip gradients + # We are clipping this way because clipping on self.modules.parameters() + # can leads to NaN/Inf gradients norm as doing the concatenation + # of all parameters in a single vector can lead to overflow/underflow. + for opt in valid_optimizers.values(): + torch.nn.utils.clip_grad_norm_( + opt.param_groups[0]["params"], self.max_grad_norm + ) + + # Note: no need to activate this flag if you are in fp16 + # since GradScaler is automatically handling the nonfinite gradients + if not self.scaler.is_enabled() and self.skip_nonfinite_grads: + self.check_gradients() + + # 4. step the valid optimizers + # If the scaler is disable, it simply calls optimizer.step() + for opt in valid_optimizers.values(): + self.scaler.step(opt) + + self.scaler.update() + + for opt in valid_optimizers.values(): + opt.zero_grad(set_to_none=True) + + self.optimizer_step += 1 + + def on_fit_batch_start(self, batch, should_step): + """Called at the beginning of ``fit_batch()``. + + This method is not called under the AMP context manager. Do not assume + automatic casting of the input batch to a lower precision (e.g. fp16). + + Arguments + --------- + batch : list of torch.Tensors + Batch of data to use for training. Default implementation assumes + this batch has two elements: inputs and targets. + should_step : boolean + Whether optimizer.step() was called or not. + """ + pass + + def on_fit_batch_end(self, batch, outputs, loss, should_step): + """Called after ``fit_batch()``. + + Arguments + --------- + batch : list of torch.Tensors + Batch of data to use for training. Default implementation assumes + this batch has two elements: inputs and targets. + outputs : list or dictionary of torch.Tensors + Returned value of compute_forward(). + loss : torch.Tensor + Returned value of compute_objectives(). + should_step : boolean + Whether optimizer.step() was called or not. + """ + pass + + @torch.no_grad() + def evaluate_batch(self, batch, stage): + """Evaluate one batch, override for different procedure than train. + + The default implementation depends on two methods being defined + with a particular behavior: + + * ``compute_forward()`` + * ``compute_objectives()`` + + Arguments + --------- + batch : list of torch.Tensors + Batch of data to use for evaluation. Default implementation assumes + this batch has two elements: inputs and targets. + stage : Stage + The stage of the experiment: Stage.VALID, Stage.TEST + + Returns + ------- + detached loss + """ + with self.evaluation_ctx: + out = self.compute_forward(batch, stage=stage) + loss = self.compute_objectives(out, batch, stage=stage) + return loss.detach().cpu() + + def _fit_train(self, train_set, epoch, enable): + # Training stage + self.on_stage_start(Stage.TRAIN, epoch) + self.modules.train() + self.zero_grad() + + # Reset nonfinite count to 0 each epoch + self.nonfinite_count = 0 + + if self.train_sampler is not None and hasattr( + self.train_sampler, "set_epoch" + ): + self.train_sampler.set_epoch(epoch) + + # Time since last intra-epoch checkpoint + last_ckpt_time = time.time() + steps_since_ckpt = 0 + with tqdm( + train_set, + initial=self.step, + dynamic_ncols=True, + disable=not enable, + colour=self.tqdm_barcolor["train"], + ) as t: + if self.profiler is not None: + self.profiler.start() + for batch in t: + if self._optimizer_step_limit_exceeded: + logger.info("Train iteration limit exceeded") + break + self.step += 1 + steps_since_ckpt += 1 + loss = self.fit_batch(batch) + self.avg_train_loss = self.update_average( + loss, self.avg_train_loss + ) + t.set_postfix(train_loss=self.avg_train_loss) + + if self.profiler is not None: + self.profiler.step() + if self.profiler.step_num > self.tot_prof_steps: + logger.info( + "The profiler finished, training is stopped." + ) + self.profiler.stop() + quit() + + # Debug mode only runs a few batches + if self.debug and self.step == self.debug_batches: + break + + if self._should_save_intra_epoch_ckpt( + last_ckpt_time, steps_since_ckpt + ): + # Checkpointer class will handle running this on main only + self._save_intra_epoch_ckpt() + last_ckpt_time = time.time() + steps_since_ckpt = 0 + + # Run train "on_stage_end" on all processes + self.zero_grad(set_to_none=True) # flush gradients + self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch) + self.avg_train_loss = 0.0 + self.step = 0 + + def _should_save_intra_epoch_ckpt(self, last_ckpt_time, steps_since_ckpt): + """Determines if an intra-epoch checkpoint should be saved. + + Returns True if there's a checkpointer and time or steps has exceeded limit. + """ + if self.checkpointer is None: + return False + + # Return early if mid-epoch checkpoints are disabled to avoid sync + if self.ckpt_interval_minutes <= 0 and self.ckpt_interval_steps <= 0: + return False + + # Check if we've run for the requested amount of time + elapsed_minutes = (time.time() - last_ckpt_time) / 60.0 + decision = 0 < self.ckpt_interval_minutes < elapsed_minutes + + # Save after requested # of steps + decision = decision or 0 < self.ckpt_interval_steps <= steps_since_ckpt + + # If the program is not distributed, just return + if not is_distributed_initialized(): + return decision + + # Otherwise, broadcast decision to all processes from main (rank 0) + # This solves synchronization issues where main gets a different + # timing result than the other processes. + else: + broadcast_list = [decision] + torch.distributed.broadcast_object_list(broadcast_list, src=0) + return broadcast_list[0] + + def _fit_valid(self, valid_set, epoch, enable): + # Validation stage + if valid_set is not None: + self.on_stage_start(Stage.VALID, epoch) + self.modules.eval() + avg_valid_loss = 0.0 + with torch.no_grad(): + for batch in tqdm( + valid_set, + dynamic_ncols=True, + disable=not enable, + colour=self.tqdm_barcolor["valid"], + ): + self.step += 1 + loss = self.evaluate_batch(batch, stage=Stage.VALID) + avg_valid_loss = self.update_average(loss, avg_valid_loss) + + # Debug mode only runs a few batches + if self.debug and self.step == self.debug_batches: + break + + self.step = 0 + self.on_stage_end(Stage.VALID, avg_valid_loss, epoch) + + def fit( + self, + epoch_counter, + train_set, + valid_set=None, + progressbar=None, + train_loader_kwargs={}, + valid_loader_kwargs={}, + ): + """Iterate epochs and datasets to improve objective. + + Relies on the existence of multiple functions that can (or should) be + overridden. The following methods are used and expected to have a + certain behavior: + + * ``fit_batch()`` + * ``evaluate_batch()`` + * ``update_average()`` + + If the initialization was done with distributed_count > 0 and the + distributed_backend is ddp, this will generally handle multiprocess + logic, like splitting the training data into subsets for each device and + only saving a checkpoint on the main process. + + Arguments + --------- + epoch_counter : iterable + Each call should return an integer indicating the epoch count. + train_set : Dataset, DataLoader + A set of data to use for training. If a Dataset is given, a + DataLoader is automatically created. If a DataLoader is given, it is + used directly. + valid_set : Dataset, DataLoader + A set of data to use for validation. If a Dataset is given, a + DataLoader is automatically created. If a DataLoader is given, it is + used directly. + progressbar : bool + Whether to display the progress of each epoch in a progressbar. + train_loader_kwargs : dict + Kwargs passed to `make_dataloader()` for making the train_loader + (if train_set is a Dataset, not DataLoader). + E.G. batch_size, num_workers. + DataLoader kwargs are all valid. + valid_loader_kwargs : dict + Kwargs passed to `make_dataloader()` for making the valid_loader + (if valid_set is a Dataset, not DataLoader). + E.g., batch_size, num_workers. + DataLoader kwargs are all valid. + + Returns + ------- + None + """ + if self.test_only: + logger.info( + "Test only mode, skipping training and validation stages." + ) + return + + if not ( + isinstance(train_set, DataLoader) + or isinstance(train_set, LoopedLoader) + ): + train_set = self.make_dataloader( + train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs + ) + if valid_set is not None and not ( + isinstance(valid_set, DataLoader) + or isinstance(valid_set, LoopedLoader) + ): + valid_set = self.make_dataloader( + valid_set, + stage=sb.Stage.VALID, + ckpt_prefix=None, + **valid_loader_kwargs, + ) + + self.on_fit_start() + + if progressbar is None: + progressbar = not self.noprogressbar + + # Only show progressbar if requested and main_process + enable = progressbar and sb.utils.distributed.if_main_process() + + # Iterate epochs + for epoch in epoch_counter: + self._fit_train(train_set=train_set, epoch=epoch, enable=enable) + self._fit_valid(valid_set=valid_set, epoch=epoch, enable=enable) + + # Debug mode only runs a few epochs + if ( + self.debug + and epoch == self.debug_epochs + or self._optimizer_step_limit_exceeded + ): + break + + @property + def _optimizer_step_limit_exceeded(self): + return ( + self.optimizer_step_limit is not None + and self.optimizer_step >= self.optimizer_step_limit + ) + + def _save_intra_epoch_ckpt(self): + """Saves a CKPT with specific intra-epoch flag.""" + self.checkpointer.save_and_keep_only( + end_of_epoch=False, + num_to_keep=1, + ckpt_predicate=lambda c: INTRA_EPOCH_CKPT_FLAG in c.meta, + meta={INTRA_EPOCH_CKPT_FLAG: True}, + verbosity=logging.DEBUG, + ) + + def _compile(self): + """Compile requested modules with either JIT or TorchInductor.""" + compile_available = hasattr(torch, "compile") + + if not compile_available and self.compile_module_keys is not None: + raise ValueError( + "'compile_module_keys' specified, but this install of PyTorch " + "seems to be too old to support it." + ) + # Modules to compile with torch.compile + compile_module_keys = set() + if self.compile: + if self.compile_module_keys is None: + compile_module_keys = set(self.modules) + else: + compile_module_keys = set(self.compile_module_keys) + logger.warning( + "--compile and --compile_module_keys are both specified. " + "Only modules specified in --compile_module_keys will be compiled." + ) + + # Modules to compile with jit + jit_module_keys = set() + if self.jit: + if self.jit_module_keys is None: + jit_module_keys = set(self.modules) + else: + jit_module_keys = set(self.jit_module_keys) + logger.warning( + "--jit and --jit_module_keys are both specified. " + "Only modules specified in --jit_module_keys will be compiled." + ) + + # find missing keys + for name in compile_module_keys | jit_module_keys: + if name not in self.modules: + raise ValueError( + f"module {name} is not defined in your hparams file." + ) + + # try 'torch.compile', remove successful compiles from JIT list + for name in compile_module_keys: + try: + module = torch.compile( + self.modules[name], + mode=self.compile_mode, + fullgraph=self.compile_using_fullgraph, + dynamic=self.compile_using_dynamic_shape_tracing, + ) + except Exception as e: + logger.warning( + f"'{name}' in 'compile_module_keys' failed to compile " + f"and will be skipped (may fallback onto JIT, if " + f"specified): {e}" + ) + continue + + self.modules[name] = module.to(self.device) + jit_module_keys.discard(name) + + for name in jit_module_keys: + module = torch.jit.script(self.modules[name]) + self.modules[name] = module.to(self.device) + + def _wrap_distributed(self): + """Wrap modules with distributed wrapper when requested.""" + if not self.distributed_launch and not self.data_parallel_backend: + return + elif self.distributed_launch: + for name, module in self.modules.items(): + if any(p.requires_grad for p in module.parameters()): + module = SyncBatchNorm.convert_sync_batchnorm(module) + if self.distributed_backend == "gloo": + module = DDP( + module, + device_ids=None, + find_unused_parameters=self.find_unused_parameters, + ) + else: + module = DDP( + module, + device_ids=[self.device], + find_unused_parameters=self.find_unused_parameters, + ) + self.modules[name] = module + else: + # data_parallel_backend + for name, module in self.modules.items(): + if any(p.requires_grad for p in module.parameters()): + module = DP(module) + self.modules[name] = module + + def evaluate( + self, + test_set, + max_key=None, + min_key=None, + progressbar=None, + test_loader_kwargs={}, + ): + """Iterate test_set and evaluate brain performance. By default, loads + the best-performing checkpoint (as recorded using the checkpointer). + + Arguments + --------- + test_set : Dataset, DataLoader + If a DataLoader is given, it is iterated directly. Otherwise passed + to ``self.make_dataloader()``. + max_key : str + Key to use for finding best checkpoint, passed to + ``on_evaluate_start()``. + min_key : str + Key to use for finding best checkpoint, passed to + ``on_evaluate_start()``. + progressbar : bool + Whether to display the progress in a progressbar. + test_loader_kwargs : dict + Kwargs passed to ``make_dataloader()`` if ``test_set`` is not a + DataLoader. NOTE: ``loader_kwargs["ckpt_prefix"]`` gets + automatically overwritten to ``None`` (so that the test DataLoader + is not added to the checkpointer). + + Returns + ------- + average test loss + """ + if progressbar is None: + progressbar = not self.noprogressbar + + # Only show progressbar if requested and main_process + enable = progressbar and sb.utils.distributed.if_main_process() + + if not ( + isinstance(test_set, DataLoader) + or isinstance(test_set, LoopedLoader) + ): + test_loader_kwargs["ckpt_prefix"] = None + test_set = self.make_dataloader( + test_set, Stage.TEST, **test_loader_kwargs + ) + self.on_evaluate_start(max_key=max_key, min_key=min_key) + self.on_stage_start(Stage.TEST, epoch=None) + self.modules.eval() + avg_test_loss = 0.0 + with torch.no_grad(): + for batch in tqdm( + test_set, + dynamic_ncols=True, + disable=not enable, + colour=self.tqdm_barcolor["test"], + ): + self.step += 1 + loss = self.evaluate_batch(batch, stage=Stage.TEST) + avg_test_loss = self.update_average(loss, avg_test_loss) + + # Debug mode only runs a few batches + if self.debug and self.step == self.debug_batches: + break + + self.on_stage_end(Stage.TEST, avg_test_loss, None) + self.step = 0 + return avg_test_loss + + def update_average(self, loss, avg_loss): + """Update running average of the loss. + + Arguments + --------- + loss : torch.tensor + detached loss, a single float value. + avg_loss : float + current running average. + + Returns + ------- + avg_loss : float + The average loss. + """ + if torch.isfinite(loss): + avg_loss -= avg_loss / self.step + avg_loss += float(loss) / self.step + return avg_loss + + @contextmanager + def no_sync(self, use=True): + """Copies pytorch's implementation for doing no_sync across all modules. + + Explanation: nn.module.no_sync() is a context manager for when one does + not want to sync gradients, which happens when using both DDP and gradient accumulation. + Speechbrain brain's class can contain multiple modules and calling no_sync on these + individually would be very awkward, therefore this contextmanager exists. + + Arguments + --------- + use : bool + If set to `False` will still sync gradients, useful to make behavior toggleable. + + Yields + ------ + None + """ + if use: + old_values_list = [] + for module in self.modules.values(): + if not hasattr(module, "require_backward_grad_sync"): + # if not using DDP + continue + old_values_list.append(module.require_backward_grad_sync) + module.require_backward_grad_sync = False + yield + i = 0 + for module in self.modules.values(): + if not hasattr(module, "require_backward_grad_sync"): + continue + module.require_backward_grad_sync = old_values_list[i] + i += 1 + else: + yield + + @sb.utils.checkpoints.mark_as_saver + def _save(self, path): + save_dict = { + "step": self.step, + "avg_train_loss": self.avg_train_loss, + "optimizer_step": self.optimizer_step, + } + with open(path, "w", encoding="utf-8") as w: + w.write(yaml.dump(save_dict)) + + @sb.utils.checkpoints.mark_as_loader + def _recover(self, path, end_of_epoch): + del end_of_epoch + with open(path, encoding="utf-8") as f: + save_dict = yaml.safe_load(f) + self.step = save_dict["step"] + self.avg_train_loss = save_dict["avg_train_loss"] + # Ensure compatibility with checkpoints from before optimizer_step: + if "optimizer_step" not in save_dict: + clsname = self.__class__.__name__ + MSG = f"'optimizer_step' not found in {clsname} checkpoint." + MSG += " Using the saved 'step' value (BACKWARDS COMPATIBILITY)" + warnings.warn(MSG) + self.optimizer_step = self.step + else: + self.optimizer_step = save_dict["optimizer_step"] diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/__init__.py new file mode 100644 index 000000000..3b2b7ab4b --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/__init__.py @@ -0,0 +1,5 @@ +"""Data loading and dataset preprocessing""" + +from speechbrain.utils.importutils import lazy_export_all + +lazy_export_all(__file__, __name__, export_subpackages=True) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/audio_io.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/audio_io.py new file mode 100644 index 000000000..821be3c2b --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/audio_io.py @@ -0,0 +1,228 @@ +""" +Lightweight soundfile-based audio I/O compatibility layer. + +This module provides a minimal compatibility wrapper for audio I/O operations +using soundfile (pysoundfile) library, replacing torchaudio's load, save, and +info functions. + +Example +------- +>>> from speechbrain.dataio import audio_io +>>> import torch +>>> # Save audio file +>>> waveform = torch.randn(1, 16000) +>>> tmpdir = getfixture("tmpdir") +>>> audio_io.save(tmpdir / "example.wav", waveform, 16000) +>>> # Load audio file +>>> audio, sr = audio_io.load(tmpdir / "example.wav") +>>> # Get audio metadata +>>> info = audio_io.info(tmpdir / "example.wav") +>>> info.duration +1.0 + +Authors + * Peter Plantinga 2025 +""" + +import dataclasses + +import numpy as np +import soundfile as sf +import torch + + +@dataclasses.dataclass +class AudioInfo: + """Container for audio file metadata, compatible with torchaudio.info output. + + Attributes + ---------- + sample_rate : int + Sample rate of the audio file. + frames : int + Total number of frames in the audio file. + channels : int + Number of audio channels. + subtype : str + Audio subtype/encoding (e.g., 'PCM_16', 'PCM_24'). + format : str + Container format (e.g., 'WAV', 'FLAC'). + """ + + sample_rate: int + frames: int + channels: int + subtype: str + format: str + + @property + def num_frames(self): + """Alias for frames for compatibility.""" + return self.frames + + @property + def num_channels(self): + """Alias for channels for compatibility.""" + return self.channels + + @property + def duration(self): + """Calculate duration in seconds.""" + return self.frames / self.sample_rate if self.sample_rate > 0 else 0.0 + + +def load( + path, + *, + channels_first=True, + dtype=None, + always_2d=True, + frame_offset=0, + num_frames=-1, +): + """Load audio file using soundfile. + + Arguments + --------- + path : str + Path to the audio file. + channels_first : bool + If True, returns tensor with shape (channels, frames). + If False, returns tensor with shape (frames, channels). + Ignored if `always_2d` is False and input is mono. + Default: True. + dtype : torch.dtype, optional + Data type for the output tensor. Respects default torch type. + If the dtype is not one of the available dtypes in soundfile, loads + with float32 first and then converts to the requested dtype. + always_2d : bool + If True, always return a 2D tensor even for mono audio. + If False, mono audio returns a 1D tensor (frames,). + Default: True. + frame_offset : int + Number of frames to skip at the start of the file. Default: 0. + num_frames : int + Number of frames to read. If -1, reads to the end of the file. Default: -1. + + Returns + ------- + tensor : torch.Tensor + Audio waveform as a tensor. + sample_rate : int + Sample rate of the audio file. + """ + try: + # Compute type for loading + dtype = dtype or torch.get_default_dtype() + _, dtype_string = str(dtype).split(".") + + # If the selected dtype is not a valid soundfile type, just use float32 + if dtype_string not in sf._ffi_types: + dtype_string = "float32" + + # Read audio file - soundfile returns (frames, channels) or (frames,) for mono + audio_np, sample_rate = sf.read( + path, + start=frame_offset, + frames=num_frames, + dtype=dtype_string, + always_2d=always_2d, + ) + + # Convert to torch tensor + audio = torch.from_numpy(audio_np).to(dtype) + + # Convert from (frames, channels) to (channels, frames) + if audio.ndim == 2 and channels_first: + audio = audio.transpose(0, 1) + + return audio, int(sample_rate) + + except Exception as e: + raise RuntimeError(f"Failed to load audio from {path}: {e}") from e + + +def save(path, src, sample_rate, channels_first=True, subtype=None): + """Save audio to file using soundfile. + + Arguments + --------- + path : str + Path where to save the audio file. + src : torch.Tensor or numpy.ndarray + Audio waveform. Can be: + - 1D tensor/array: (frames,) - mono + - 2D tensor/array: + - (channels, frames) if channels_first=True + - (frames, channels) if channels_first=False + sample_rate : int + Sample rate for the audio file. + channels_first : bool + If True, input is assumed to be (channels, frames) + If False, input is assumed to be (frames, channels). + Ignored if input is 1D tensor/array. + Default: True. + subtype : str, optional + Audio encoding subtype (e.g., 'PCM_16', 'PCM_24', 'PCM_32', 'FLOAT'). + If None, soundfile will choose an appropriate subtype based on the file format. + Default: None. + """ + try: + # Convert to numpy if needed + if isinstance(src, torch.Tensor): + audio_np = src.detach().cpu().numpy() + else: + audio_np = np.asarray(src) + + # Convert to (frames, channels) if channels_first is True + if audio_np.ndim == 2 and channels_first: + audio_np = audio_np.T + + if audio_np.ndim not in [1, 2]: + raise ValueError( + f"Unsupported audio shape: {audio_np.shape}. " + "Expected 1D frames or 2D channels and frames." + ) + + sf.write(path, audio_np, sample_rate, subtype=subtype) + + except Exception as e: + raise RuntimeError(f"Failed to save audio to {path}: {e}") from e + + +def info(path): + """Get audio file metadata using soundfile. + + Arguments + --------- + path : str + Path to the audio file. + + Returns + ------- + AudioInfo + Object containing audio metadata (sample_rate, frames, channels, + subtype, format, duration). + """ + try: + file_info = sf.info(path) + return AudioInfo( + sample_rate=file_info.samplerate, + frames=file_info.frames, + channels=file_info.channels, + subtype=file_info.subtype, + format=file_info.format, + ) + except Exception as e: + raise RuntimeError(f"Failed to get info for {path}: {e}") from e + + +def list_audio_backends(): + """List available audio backends. + + Returns + ------- + list of str + List of available backend names. Currently only ['soundfile']. + """ + return ["soundfile"] diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/batch.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/batch.py new file mode 100644 index 000000000..b0fa21071 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/batch.py @@ -0,0 +1,333 @@ +"""Batch collation + +Authors + * Aku Rouhe 2020 +""" + +import collections + +import torch +from torch.utils.data._utils.collate import default_convert +from torch.utils.data._utils.pin_memory import ( + pin_memory as recursive_pin_memory, +) + +from speechbrain.utils.data_utils import ( + batch_pad_right, + mod_default_collate, + recursive_to, +) + +PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"]) + + +class PaddedBatch: + """Collate_fn when examples are dicts and have variable-length sequences. + + Different elements in the examples get matched by key. + All numpy tensors get converted to Torch (PyTorch default_convert) + Then, by default, all torch.Tensor valued elements get padded and support + collective pin_memory() and to() calls. + Regular Python data types are just collected in a list. + + Arguments + --------- + examples : list + List of example dicts, as produced by Dataloader. + padded_keys : list, None + (Optional) List of keys to pad on. If None, pad all torch.Tensors + device_prep_keys : list, None + (Optional) Only these keys participate in collective memory pinning and moving with + to(). + If None, defaults to all items with torch.Tensor values. + padding_func : callable, optional + Called with a list of tensors to be padded together. Needs to return + two tensors: the padded data, and another tensor for the data lengths. + padding_kwargs : dict, None + (Optional) Extra kwargs to pass to padding_func. E.G. mode, value + This is used as the default padding configuration for all keys. + per_key_padding_kwargs : dict, None + (Optional) Per-key padding configuration. Keys in this dict should match + the keys in the examples. Each value should be a dict with padding parameters + (e.g., {'value': -100, 'mode': 'constant'}). If a key is not in this dict, + the global padding_kwargs will be used. + apply_default_convert : bool + Whether to apply PyTorch default_convert (numpy to torch recursively, + etc.) on all data. Default:True, usually does the right thing. + nonpadded_stack : bool + Whether to apply PyTorch-default_collate-like stacking on values that + didn't get padded. This stacks if it can, but doesn't error out if it + cannot. Default:True, usually does the right thing. + + Example + ------- + >>> batch = PaddedBatch( + ... [ + ... {"id": "ex1", "foo": torch.Tensor([1.0])}, + ... {"id": "ex2", "foo": torch.Tensor([2.0, 1.0])}, + ... ] + ... ) + >>> # Attribute or key-based access: + >>> batch.id + ['ex1', 'ex2'] + >>> batch["id"] + ['ex1', 'ex2'] + >>> # torch.Tensors get padded + >>> type(batch.foo) + + >>> batch.foo.data + tensor([[1., 0.], + [2., 1.]]) + >>> batch.foo.lengths + tensor([0.5000, 1.0000]) + >>> # Batch supports collective operations: + >>> _ = batch.to(dtype=torch.half) + >>> batch.foo.data + tensor([[1., 0.], + [2., 1.]], dtype=torch.float16) + >>> batch.foo.lengths + tensor([0.5000, 1.0000], dtype=torch.float16) + >>> # Numpy tensors get converted to torch and padded as well: + >>> import numpy as np + >>> batch = PaddedBatch( + ... [{"wav": np.asarray([1, 2, 3, 4])}, {"wav": np.asarray([1, 2, 3])}] + ... ) + >>> batch.wav # +ELLIPSIS + PaddedData(data=tensor([[1, 2,... + >>> # Basic stacking collation deals with non padded data: + >>> batch = PaddedBatch( + ... [ + ... { + ... "spk_id": torch.tensor([1]), + ... "wav": torch.tensor([0.1, 0.0, 0.3]), + ... }, + ... { + ... "spk_id": torch.tensor([2]), + ... "wav": torch.tensor([0.2, 0.3, -0.1]), + ... }, + ... ], + ... padded_keys=["wav"], + ... ) + >>> batch.spk_id + tensor([[1], + [2]]) + >>> # And some data is left alone: + >>> batch = PaddedBatch( + ... [{"text": ["Hello"]}, {"text": ["How", "are", "you?"]}] + ... ) + >>> batch.text + [['Hello'], ['How', 'are', 'you?']] + >>> # Per-key padding configuration: + >>> batch = PaddedBatch( + ... [ + ... { + ... "wav": torch.tensor([1, 2, 3]), + ... "labels": torch.tensor([1, 2]), + ... }, + ... {"wav": torch.tensor([4, 5]), "labels": torch.tensor([3])}, + ... ], + ... per_key_padding_kwargs={ + ... "wav": {"value": 0}, + ... "labels": {"value": -100}, + ... }, + ... ) + >>> batch.wav.data + tensor([[1, 2, 3], + [4, 5, 0]]) + >>> batch.labels.data + tensor([[ 1, 2], + [ 3, -100]]) + + """ + + def __init__( + self, + examples, + padded_keys=None, + device_prep_keys=None, + padding_func=batch_pad_right, + padding_kwargs=None, + per_key_padding_kwargs=None, + apply_default_convert=True, + nonpadded_stack=True, + ): + padding_kwargs = padding_kwargs if padding_kwargs is not None else {} + per_key_padding_kwargs = ( + per_key_padding_kwargs if per_key_padding_kwargs is not None else {} + ) + self.__length = len(examples) + self.__keys = list(examples[0].keys()) + self.__padded_keys = [] + self.__device_prep_keys = [] + for key in self.__keys: + values = [example[key] for example in examples] + # Default convert usually does the right thing (numpy2torch etc.) + if apply_default_convert: + values = default_convert(values) + if (padded_keys is not None and key in padded_keys) or ( + padded_keys is None and isinstance(values[0], torch.Tensor) + ): + # Padding and PaddedData + self.__padded_keys.append(key) + + # Use per-key padding config if available, otherwise fall back to global padding_kwargs + if key in per_key_padding_kwargs: + key_padding_kwargs = per_key_padding_kwargs[key] + else: + key_padding_kwargs = padding_kwargs + padded = PaddedData(*padding_func(values, **key_padding_kwargs)) + setattr(self, key, padded) + else: + # Default PyTorch collate usually does the right thing + # (convert lists of equal sized tensors to batch tensors, etc.) + if nonpadded_stack: + values = mod_default_collate(values) + setattr(self, key, values) + if (device_prep_keys is not None and key in device_prep_keys) or ( + device_prep_keys is None and isinstance(values[0], torch.Tensor) + ): + self.__device_prep_keys.append(key) + + def __len__(self): + return self.__length + + def __getitem__(self, key): + if key in self.__keys: + return getattr(self, key) + else: + raise KeyError(f"Batch doesn't have key: {key}") + + def __iter__(self): + """Iterates over the different elements of the batch. + + Returns + ------- + Iterator over the batch. + + Example + ------- + >>> batch = PaddedBatch( + ... [ + ... {"id": "ex1", "val": torch.Tensor([1.0])}, + ... {"id": "ex2", "val": torch.Tensor([2.0, 1.0])}, + ... ] + ... ) + >>> ids, vals = batch + >>> ids + ['ex1', 'ex2'] + """ + return iter(getattr(self, key) for key in self.__keys) + + def pin_memory(self): + """In-place, moves relevant elements to pinned memory.""" + for key in self.__device_prep_keys: + value = getattr(self, key) + pinned = recursive_pin_memory(value) + setattr(self, key, pinned) + return self + + def to(self, *args, **kwargs): + """In-place move/cast relevant elements. + + Passes all arguments to torch.Tensor.to, see its documentation. + """ + for key in self.__device_prep_keys: + value = getattr(self, key) + moved = recursive_to(value, *args, **kwargs) + setattr(self, key, moved) + return self + + def at_position(self, pos): + """Gets the position.""" + key = self.__keys[pos] + return getattr(self, key) + + @property + def batchsize(self): + """Returns the bach size""" + return self.__length + + +class BatchsizeGuesser: + """Try to figure out the batchsize, but never error out + + If this cannot figure out anything else, will fallback to guessing 1 + + Example + ------- + >>> guesser = BatchsizeGuesser() + >>> # Works with simple tensors: + >>> guesser(torch.randn((2, 3))) + 2 + >>> # Works with sequences of tensors: + >>> guesser((torch.randn((2, 3)), torch.randint(high=5, size=(2,)))) + 2 + >>> # Works with PaddedBatch: + >>> guesser( + ... PaddedBatch([{"wav": [1.0, 2.0, 3.0]}, {"wav": [4.0, 5.0, 6.0]}]) + ... ) + 2 + >>> guesser("Even weird non-batches have a fallback") + 1 + + """ + + def __init__(self): + self.method = None + + def __call__(self, batch): + try: + return self.method(batch) + except: # noqa: E722 + return self.find_suitable_method(batch) + + def find_suitable_method(self, batch): + """Try the different methods and note which worked""" + try: + bs = self.attr_based(batch) + self.method = self.attr_based + return bs + except: # noqa: E722 + pass + try: + bs = self.torch_tensor_bs(batch) + self.method = self.torch_tensor_bs + return bs + except: # noqa: E722 + pass + try: + bs = self.len_of_first(batch) + self.method = self.len_of_first + return bs + except: # noqa: E722 + pass + try: + bs = self.len_of_iter_first(batch) + self.method = self.len_of_iter_first + return bs + except: # noqa: E722 + pass + # Last ditch fallback: + bs = self.fallback(batch) + self.method = self.fallback(batch) + return bs + + def attr_based(self, batch): + """Implementation of attr_based.""" + return batch.batchsize + + def torch_tensor_bs(self, batch): + """Implementation of torch_tensor_bs.""" + return batch.shape[0] + + def len_of_first(self, batch): + """Implementation of len_of_first.""" + return len(batch[0]) + + def len_of_iter_first(self, batch): + """Implementation of len_of_iter_first.""" + return len(next(iter(batch))) + + def fallback(self, batch): + """Implementation of fallback.""" + return 1 diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataio.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataio.py new file mode 100644 index 000000000..0385ade1c --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataio.py @@ -0,0 +1,1417 @@ +""" +Data reading and writing. + +Authors + * Mirco Ravanelli 2020 + * Aku Rouhe 2020 + * Ju-Chieh Chou 2020 + * Samuele Cornell 2020 + * Abdel HEBA 2020 + * Gaëlle Laperrière 2021 + * Sahar Ghannay 2021 + * Sylvain de Langen 2022 + * Adel Moumen 2025 +""" + +import csv +import hashlib +import json +import os +import pickle +import re +import time +from io import BytesIO +from typing import Union + +import numpy as np +import torch + +from speechbrain.dataio import audio_io +from speechbrain.utils.logger import get_logger +from speechbrain.utils.torch_audio_backend import ( + check_torchaudio_backend, + validate_backend, +) + +check_torchaudio_backend() +logger = get_logger(__name__) + + +def load_data_json(json_path, replacements=None): + """Loads JSON and recursively formats string values. + + Arguments + --------- + json_path : str + Path to CSV file. + replacements : dict + (Optional dict), e.g., {"data_folder": "/home/speechbrain/data"}. + This is used to recursively format all string values in the data. + + Returns + ------- + dict + JSON data with replacements applied. + + Example + ------- + >>> json_spec = '''{ + ... "ex1": {"files": ["{ROOT}/mic1/ex1.wav", "{ROOT}/mic2/ex1.wav"], "id": 1}, + ... "ex2": {"files": [{"spk1": "{ROOT}/ex2.wav"}, {"spk2": "{ROOT}/ex2.wav"}], "id": 2} + ... } + ... ''' + >>> tmpfile = getfixture("tmpdir") / "test.json" + >>> with open(tmpfile, "w", encoding="utf-8") as fo: + ... _ = fo.write(json_spec) + >>> data = load_data_json(tmpfile, {"ROOT": "/home"}) + >>> data["ex1"]["files"][0] + '/home/mic1/ex1.wav' + >>> data["ex2"]["files"][1]["spk2"] + '/home/ex2.wav' + + """ + if replacements is None: + replacements = {} + with open(json_path, encoding="utf-8") as f: + out_json = json.load(f) + _recursive_format(out_json, replacements) + return out_json + + +def _recursive_format(data, replacements): + # Data: dict or list, replacements : dict + # Replaces string keys in replacements by their values + # at all levels of data (in str values) + # Works in-place. + if isinstance(data, dict): + for key, item in data.items(): + if isinstance(item, dict) or isinstance(item, list): + _recursive_format(item, replacements) + elif isinstance(item, str): + data[key] = item.format_map(replacements) + # If not dict, list or str, do nothing + if isinstance(data, list): + for i, item in enumerate(data): + if isinstance(item, dict) or isinstance(item, list): + _recursive_format(item, replacements) + elif isinstance(item, str): + data[i] = item.format_map(replacements) + # If not dict, list or str, do nothing + + +def load_data_csv(csv_path, replacements=None): + """Loads CSV and formats string values. + + Uses the SpeechBrain legacy CSV data format, where the CSV must have an + 'ID' field. + If there is a field called duration, it is interpreted as a float. + The rest of the fields are left as they are (legacy _format and _opts fields + are not used to load the data in any special way). + + Bash-like string replacements with $to_replace are supported. + + Arguments + --------- + csv_path : str + Path to CSV file. + replacements : dict + (Optional dict), e.g., {"data_folder": "/home/speechbrain/data"} + This is used to recursively format all string values in the data. + + Returns + ------- + dict + CSV data with replacements applied. + + Example + ------- + >>> csv_spec = '''ID,duration,wav_path + ... utt1,1.45,$data_folder/utt1.wav + ... utt2,2.0,$data_folder/utt2.wav + ... ''' + >>> tmpfile = getfixture("tmpdir") / "test.csv" + >>> with open(tmpfile, "w", encoding="utf-8") as fo: + ... _ = fo.write(csv_spec) + >>> data = load_data_csv(tmpfile, {"data_folder": "/home"}) + >>> data["utt1"]["wav_path"] + '/home/utt1.wav' + """ + + if replacements is None: + replacements = {} + with open(csv_path, newline="", encoding="utf-8") as csvfile: + result = {} + reader = csv.DictReader(csvfile, skipinitialspace=True) + variable_finder = re.compile(r"\$([\w.]+)") + for row in reader: + # ID: + try: + data_id = row["ID"] + del row["ID"] # This is used as a key in result, instead. + except KeyError: + raise KeyError( + "CSV has to have an 'ID' field, with unique ids" + " for all data points" + ) + if data_id in result: + raise ValueError(f"Duplicate id: {data_id}") + # Replacements: + for key, value in row.items(): + try: + row[key] = variable_finder.sub( + lambda match: str(replacements[match[1]]), value + ) + except KeyError: + raise KeyError( + f"The item {value} requires replacements " + "which were not supplied." + ) + # Duration: + if "duration" in row: + row["duration"] = float(row["duration"]) + result[data_id] = row + return result + + +def read_audio_info(path, backend=None) -> "audio_io.AudioInfo": + """Retrieves audio metadata from a file path. Uses audio_io.info which is + based on soundfile. + + Note that this may cause full file traversal in certain cases! + + Arguments + --------- + path : str + Path to the audio file to examine. + backend : str, optional + Audio backend to use for loading the audio file. This parameter is + kept for compatibility but is currently ignored (soundfile is always used). + + Returns + ------- + audio_io.AudioInfo + Audio metadata with fields: sample_rate, num_frames, channels, etc. + + NOTE + ---- + Some codecs, such as MP3, require full file traversal for accurate length + information to be retrieved. + In these cases, you may as well read the entire audio file to avoid doubling + the processing time. + """ + if backend is not None: + validate_backend(backend) + + # Use audio_io.info which is based on soundfile + info = audio_io.info(path) + + # Soundfile generally provides reliable frame counts, but if for some + # reason num_frames is 0, we can fall back to loading the file + if info.num_frames == 0: + channels_data, sample_rate = audio_io.load(path) + info.num_frames = channels_data.size(-1) # frames dimension + info.sample_rate = sample_rate + + return info + + +def read_audio(waveforms_obj, backend=None): + """General audio loading, based on a custom notation. + + Expected use case is in conjunction with Datasets + specified by JSON. + + The parameter may just be a path to a file: + `read_audio("/path/to/wav1.wav")` + + Alternatively, you can specify more options in a dict, e.g.: + ``` + # load a file from sample 8000 through 15999 + read_audio({"file": "/path/to/wav2.wav", "start": 8000, "stop": 16000}) + ``` + + Which codecs are supported depends on the soundfile library. + Refer to `audio_io.load` documentation for further details. + + Arguments + --------- + waveforms_obj : str, dict + Path to audio or dict with the desired configuration. + + Keys for the dict variant: + - `"file"` (str): Path to the audio file. + - `"start"` (int, optional): The first sample to load. + If unspecified, load from the very first frame. + - `"stop"` (int, optional): The last sample to load (exclusive). + If unspecified or equal to start, load from `start` to the end. + Will not fail if `stop` is past the sample count of the file and will + return less frames. + backend : str, optional + Audio backend to use for loading the audio file. Must be one of + 'ffmpeg', 'sox', 'soundfile' or None. If None, uses torchaudio's default backend. + + Returns + ------- + torch.Tensor + 1-channel: audio tensor with shape: `(samples, )`. + >=2-channels: audio tensor with shape: `(samples, channels)`. + + Raises + ------ + ValueError + If the `backend` is not one of the allowed values. + Must be one of [None, 'ffmpeg', 'sox', 'soundfile']. + + Example + ------- + >>> dummywav = torch.rand(16000) + >>> import os + >>> tmpfile = str(getfixture("tmpdir") / "wave.wav") + >>> write_audio(tmpfile, dummywav, 16000) + >>> asr_example = {"wav": tmpfile, "spk_id": "foo", "words": "foo bar"} + >>> loaded = read_audio(asr_example["wav"]) + >>> loaded.allclose( + ... dummywav.squeeze(0), atol=1e-4 + ... ) # replace with eq with sox_io backend + True + """ + validate_backend(backend) + + # Case 1: Directly a file path (str) or file-like object or raw bytes. + # If a file-like object, ensure the pointer is at the beginning. + if hasattr(waveforms_obj, "seek"): + waveforms_obj.seek(0) + + if isinstance(waveforms_obj, (str, BytesIO, bytes)): + # If raw bytes, wrap them in a BytesIO. + if isinstance(waveforms_obj, bytes): + waveforms_obj = BytesIO(waveforms_obj) + waveforms_obj.seek(0) + audio, _ = audio_io.load(waveforms_obj) + # Case 2: A dict with more options. Only works with file paths. + else: + path = waveforms_obj["file"] + start = waveforms_obj.get("start", 0) + # To match past SB behavior, `start == stop` or omitted `stop` means to + # load all frames from `start` to the file end. + stop = waveforms_obj.get("stop", start) + + if start < 0: + raise ValueError( + f"Invalid sample range (start < 0): {start}..{stop}!" + ) + + if stop < start: + # Could occur if the user tried one of two things: + # - specify a negative value as an attempt to index from the end; + # - specify -1 as an attempt to load up to the last sample. + raise ValueError( + f"Invalid sample range (stop < start): {start}..{stop}!\n" + 'Hint: Omit "stop" if you want to read to the end of file.' + ) + + # Requested to load until a specific frame? + if start != stop: + num_frames = stop - start + audio, fs = audio_io.load( + path, num_frames=num_frames, frame_offset=start + ) + else: + # Load to the end. + audio, fs = audio_io.load(path, frame_offset=start) + + audio = audio.transpose(0, 1) + return audio.squeeze(1) + + +def read_audio_multichannel(waveforms_obj, backend=None): + """General audio loading, based on a custom notation. + + Expected use case is in conjunction with Datasets + specified by JSON. + + The custom notation: + + The annotation can be just a path to a file: + "/path/to/wav1.wav" + + Multiple (possibly multi-channel) files can be specified, as long as they + have the same length: + {"files": [ + "/path/to/wav1.wav", + "/path/to/wav2.wav" + ] + } + + Or you can specify a single file more succinctly: + {"files": "/path/to/wav2.wav"} + + Offset number samples and stop number samples also can be specified to read + only a segment within the files. + {"files": [ + "/path/to/wav1.wav", + "/path/to/wav2.wav" + ] + "start": 8000 + "stop": 16000 + } + + Arguments + --------- + waveforms_obj : str, dict + Audio reading annotation, see above for format. + backend : str, optional + Audio backend to use for loading the audio file. Must be one of + 'ffmpeg', 'sox', 'soundfile' or None. If None, uses torchaudio's default backend. + + Raises + ------ + ValueError + If the `backend` is not one of the allowed values. + Must be one of [None, 'ffmpeg', 'sox', 'soundfile']. + + Returns + ------- + torch.Tensor + Audio tensor with shape: (samples, ). + + Example + ------- + >>> dummywav = torch.rand(16000, 2) + >>> import os + >>> tmpfile = str(getfixture("tmpdir") / "wave.wav") + >>> write_audio(tmpfile, dummywav, 16000) + >>> asr_example = {"wav": tmpfile, "spk_id": "foo", "words": "foo bar"} + >>> loaded = read_audio(asr_example["wav"]) + >>> loaded.allclose( + ... dummywav.squeeze(0), atol=1e-4 + ... ) # replace with eq with sox_io backend + True + """ + validate_backend(backend) + + # Case 1: Directly a file path (str) or file-like object or raw bytes. + # If a file-like object, ensure the pointer is at the beginning. + if hasattr(waveforms_obj, "seek"): + waveforms_obj.seek(0) + + if isinstance(waveforms_obj, (str, BytesIO, bytes)): + # If raw bytes, wrap them in a BytesIO. + if isinstance(waveforms_obj, bytes): + waveforms_obj = BytesIO(waveforms_obj) + waveforms_obj.seek(0) + audio, _ = audio_io.load(waveforms_obj) + return audio.transpose(0, 1) + + # Case 2: A dict with more options. Only works with file paths. + files = waveforms_obj["files"] + if not isinstance(files, list): + files = [files] + + waveforms = [] + start = waveforms_obj.get("start", 0) + # Default stop to start -> if not specified, num_frames becomes 0, + # which is the torchaudio default + stop = waveforms_obj.get("stop", start - 1) + num_frames = stop - start + for f in files: + audio, fs = audio_io.load(f, num_frames=num_frames, frame_offset=start) + waveforms.append(audio) + + out = torch.cat(waveforms, 0) + return out.transpose(0, 1) + + +def write_audio(filepath, audio, samplerate): + """Write audio on disk. It is basically a wrapper to support saving + audio signals in the speechbrain format (audio, channels). + + Arguments + --------- + filepath: path + Path where to save the audio file. + audio : torch.Tensor + Audio file in the expected speechbrain format (signal, channels). + samplerate: int + Sample rate (e.g., 16000). + + + Example + ------- + >>> import os + >>> tmpfile = str(getfixture("tmpdir") / "wave.wav") + >>> dummywav = torch.rand(16000, 2) + >>> write_audio(tmpfile, dummywav, 16000) + >>> loaded = read_audio(tmpfile) + >>> loaded.allclose( + ... dummywav, atol=1e-4 + ... ) # replace with eq with sox_io backend + True + """ + if len(audio.shape) == 2: + audio = audio.transpose(0, 1) + elif len(audio.shape) == 1: + audio = audio.unsqueeze(0) + + audio_io.save(filepath, audio, samplerate) + + +def load_pickle(pickle_path): + """Utility function for loading .pkl pickle files. + + Arguments + --------- + pickle_path : str + Path to pickle file. + + Returns + ------- + out : object + Python object loaded from pickle. + """ + with open(pickle_path, "rb") as f: + out = pickle.load(f) + return out + + +def to_floatTensor(x: Union[list, tuple, np.ndarray]): + """ + Arguments + --------- + x : (list, tuple, np.ndarray) + Input data to be converted to torch float. + + Returns + ------- + tensor : torch.Tensor + Data now in torch.tensor float datatype. + """ + if isinstance(x, torch.Tensor): + return x.float() + if isinstance(x, np.ndarray): + return torch.from_numpy(x).float() + else: + return torch.tensor(x, dtype=torch.float) + + +def to_doubleTensor(x: Union[list, tuple, np.ndarray]): + """ + Arguments + --------- + x : (list, tuple, np.ndarray) + Input data to be converted to torch double. + + Returns + ------- + tensor : torch.Tensor + Data now in torch.tensor double datatype. + """ + if isinstance(x, torch.Tensor): + return x.double() + if isinstance(x, np.ndarray): + return torch.from_numpy(x).double() + else: + return torch.tensor(x, dtype=torch.double) + + +def to_longTensor(x: Union[list, tuple, np.ndarray]): + """ + Arguments + --------- + x : (list, tuple, np.ndarray) + Input data to be converted to torch long. + + Returns + ------- + tensor : torch.Tensor + Data now in torch.tensor long datatype. + """ + if isinstance(x, torch.Tensor): + return x.long() + if isinstance(x, np.ndarray): + return torch.from_numpy(x).long() + else: + return torch.tensor(x, dtype=torch.long) + + +def convert_index_to_lab(batch, ind2lab): + """Convert a batch of integer IDs to string labels. + + Arguments + --------- + batch : list + List of lists, a batch of sequences. + ind2lab : dict + Mapping from integer IDs to labels. + + Returns + ------- + list + List of lists, same size as batch, with labels from ind2lab. + + Example + ------- + >>> ind2lab = {1: "h", 2: "e", 3: "l", 4: "o"} + >>> out = convert_index_to_lab([[4, 1], [1, 2, 3, 3, 4]], ind2lab) + >>> for seq in out: + ... print("".join(seq)) + oh + hello + """ + return [[ind2lab[int(index)] for index in seq] for seq in batch] + + +def relative_time_to_absolute(batch, relative_lens, rate): + """Converts SpeechBrain style relative length to the absolute duration. + + Operates on batch level. + + Arguments + --------- + batch : torch.Tensor + Sequences to determine the duration for. + relative_lens : torch.Tensor + The relative length of each sequence in batch. The longest sequence in + the batch needs to have relative length 1.0. + rate : float + The rate at which sequence elements occur in real-world time. Sample + rate, if batch is raw wavs (recommended) or 1/frame_shift if batch is + features. This has to have 1/s as the unit. + + Returns + ------- + torch.Tensor + Duration of each sequence in seconds. + + Example + ------- + >>> batch = torch.ones(2, 16000) + >>> relative_lens = torch.tensor([3.0 / 4.0, 1.0]) + >>> rate = 16000 + >>> print(relative_time_to_absolute(batch, relative_lens, rate)) + tensor([0.7500, 1.0000]) + """ + max_len = batch.shape[1] + durations = torch.round(relative_lens * max_len) / rate + return durations + + +class IterativeCSVWriter: + """Write CSV files a line at a time. + + Arguments + --------- + outstream : file-object + A writeable stream + data_fields : list + List of the optional keys to write. Each key will be expanded to the + SpeechBrain format, producing three fields: key, key_format, key_opts. + defaults : dict + Mapping from CSV key to corresponding default value. + + Example + ------- + >>> import io + >>> f = io.StringIO() + >>> writer = IterativeCSVWriter(f, ["phn"]) + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + >>> writer.write("UTT1", 2.5, "sil hh ee ll ll oo sil", "string", "") + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + UTT1,2.5,sil hh ee ll ll oo sil,string, + >>> writer.write( + ... ID="UTT2", phn="sil ww oo rr ll dd sil", phn_format="string" + ... ) + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + UTT1,2.5,sil hh ee ll ll oo sil,string, + UTT2,,sil ww oo rr ll dd sil,string, + >>> writer.set_default("phn_format", "string") + >>> writer.write_batch(ID=["UTT3", "UTT4"], phn=["ff oo oo", "bb aa rr"]) + >>> print(f.getvalue()) + ID,duration,phn,phn_format,phn_opts + UTT1,2.5,sil hh ee ll ll oo sil,string, + UTT2,,sil ww oo rr ll dd sil,string, + UTT3,,ff oo oo,string, + UTT4,,bb aa rr,string, + """ + + def __init__(self, outstream, data_fields, defaults=None): + if defaults is None: + defaults = {} + self._outstream = outstream + self.fields = ["ID", "duration"] + self._expand_data_fields(data_fields) + self.defaults = defaults + self._outstream.write(",".join(self.fields)) + + def set_default(self, field, value): + """Sets a default value for the given CSV field. + + Arguments + --------- + field : str + A field in the CSV. + value : str + The default value. + """ + if field not in self.fields: + raise ValueError(f"{field} is not a field in this CSV!") + self.defaults[field] = value + + def write(self, *args, **kwargs): + """Writes one data line into the CSV. + + Arguments + --------- + *args : tuple + Supply every field with a value in positional form OR. + **kwargs : dict + Supply certain fields by key. The ID field is mandatory for all + lines, but others can be left empty. + """ + if args: + if len(args) != len(self.fields): + raise ValueError("Need consistent fields") + to_write = [str(arg) for arg in args] + if kwargs: + raise ValueError( + "Use either positional fields or named fields, " + "but not both." + ) + else: + if kwargs: + if "ID" not in kwargs: + raise ValueError("I'll need to see some ID") + full_vals = self.defaults.copy() + full_vals.update(kwargs) + to_write = [ + str(full_vals.get(field, "")) for field in self.fields + ] + else: + raise ValueError( + "Use either positional fields or named fields." + ) + self._outstream.write("\n") + self._outstream.write(",".join(to_write)) + + def write_batch(self, *args, **kwargs): + """Writes a batch of lines into the CSV. + + Here each argument should be a list with the same length. + + Arguments + --------- + *args : tuple + Supply every field with a value in positional form OR. + **kwargs : dict + Supply certain fields by key. The ID field is mandatory for all + lines, but others can be left empty. + """ + if args and kwargs: + raise ValueError( + "Use either positional fields or named fields, but not both." + ) + if args: + if len(args) != len(self.fields): + raise ValueError("Need consistent fields") + for arg_row in zip(*args): + self.write(*arg_row) + if kwargs: + if "ID" not in kwargs: + raise ValueError("I'll need to see some ID") + keys = kwargs.keys() + for value_row in zip(*kwargs.values()): + kwarg_row = dict(zip(keys, value_row)) + self.write(**kwarg_row) + + @staticmethod + def _expand_data_fields(data_fields): + expanded = [] + for data_field in data_fields: + expanded.append(data_field) + expanded.append(data_field + "_format") + expanded.append(data_field + "_opts") + return expanded + + +def write_txt_file(data, filename, sampling_rate=None): + """Write data in text format. + + Arguments + --------- + data : str, list, torch.Tensor, numpy.ndarray + The data to write in the text file. + filename : str + Path to file where to write the data. + sampling_rate : None + Not used, just here for interface compatibility. + + Example + ------- + >>> tmpdir = getfixture("tmpdir") + >>> signal = torch.tensor([1, 2, 3, 4]) + >>> write_txt_file(signal, tmpdir / "example.txt") + """ + del sampling_rate # Not used. + # Check if the path of filename exists + os.makedirs(os.path.dirname(filename), exist_ok=True) + with open(filename, "w", encoding="utf-8") as fout: + if isinstance(data, torch.Tensor): + data = data.tolist() + if isinstance(data, np.ndarray): + data = data.tolist() + if isinstance(data, list): + for line in data: + print(line, file=fout) + if isinstance(data, str): + print(data, file=fout) + + +def write_stdout(data, filename=None, sampling_rate=None): + """Write data to standard output. + + Arguments + --------- + data : str, list, torch.Tensor, numpy.ndarray + The data to write in the text file. + filename : None + Not used, just here for compatibility. + sampling_rate : None + Not used, just here for compatibility. + + Example + ------- + >>> tmpdir = getfixture("tmpdir") + >>> signal = torch.tensor([[1, 2, 3, 4]]) + >>> write_stdout(signal, tmpdir / "example.txt") + [1, 2, 3, 4] + """ + # Managing Torch.Tensor + if isinstance(data, torch.Tensor): + data = data.tolist() + # Managing np.ndarray + if isinstance(data, np.ndarray): + data = data.tolist() + if isinstance(data, list): + for line in data: + print(line) + if isinstance(data, str): + print(data) + + +def length_to_mask(length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + + Example + ------- + >>> length = torch.Tensor([1, 2, 3]) + >>> mask = length_to_mask(length) + >>> mask + tensor([[1., 0., 0.], + [1., 1., 0.], + [1., 1., 1.]]) + """ + assert len(length.shape) == 1 + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange( + max_len, device=length.device, dtype=length.dtype + ).expand(len(length), max_len) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + if device is None: + device = length.device + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + +def read_kaldi_lab(kaldi_ali, kaldi_lab_opts): + """Read labels in kaldi format. + + Uses kaldi IO. + + Arguments + --------- + kaldi_ali : str + Path to directory where kaldi alignments are stored. + kaldi_lab_opts : str + A string that contains the options for reading the kaldi alignments. + + Returns + ------- + lab : dict + A dictionary containing the labels. + + Note + ---- + This depends on kaldi-io-for-python. Install it separately. + See: https://github.com/vesis84/kaldi-io-for-python + + Example + ------- + This example requires kaldi files. + ``` + lab_folder = "/home/kaldi/egs/TIMIT/s5/exp/dnn4_pretrain-dbn_dnn_ali" + read_kaldi_lab(lab_folder, "ali-to-pdf") + ``` + """ + # EXTRA TOOLS + try: + import kaldi_io + except ImportError: + raise ImportError("Could not import kaldi_io. Install it to use this.") + # Reading the Kaldi labels + lab = { + k: v + for k, v in kaldi_io.read_vec_int_ark( + "gunzip -c " + + kaldi_ali + + "/ali*.gz | " + + kaldi_lab_opts + + " " + + kaldi_ali + + "/final.mdl ark:- ark:-|" + ) + } + return lab + + +def get_md5(file): + """Get the md5 checksum of an input file. + + Arguments + --------- + file : str + Path to file for which compute the checksum. + + Returns + ------- + md5 + Checksum for the given filepath. + + Example + ------- + >>> get_md5("tests/samples/single-mic/example1.wav") + 'c482d0081ca35302d30d12f1136c34e5' + """ + # Lets read stuff in 64kb chunks! + BUF_SIZE = 65536 + md5 = hashlib.md5() + # Computing md5 + with open(file, "rb") as f: + while True: + data = f.read(BUF_SIZE) + if not data: + break + md5.update(data) + return md5.hexdigest() + + +def save_md5(files, out_file): + """Saves the md5 of a list of input files as a pickled dict into a file. + + Arguments + --------- + files : list + List of input files from which we will compute the md5. + out_file : str + The path where to store the output pkl file. + + Example + ------- + >>> files = ["tests/samples/single-mic/example1.wav"] + >>> tmpdir = getfixture("tmpdir") + >>> save_md5(files, tmpdir / "md5.pkl") + """ + # Initialization of the dictionary + md5_dict = {} + # Computing md5 for all the files in the list + for file in files: + md5_dict[file] = get_md5(file) + # Saving dictionary in pkl format + save_pkl(md5_dict, out_file) + + +def save_pkl(obj, file): + """Save an object in pkl format. + + Arguments + --------- + obj : object + Object to save in pkl format + file : str + Path to the output file + + Example + ------- + >>> tmpfile = getfixture("tmpdir") / "example.pkl" + >>> save_pkl([1, 2, 3, 4, 5], tmpfile) + >>> load_pkl(tmpfile) + [1, 2, 3, 4, 5] + """ + with open(file, "wb") as f: + pickle.dump(obj, f) + + +def load_pkl(file): + """Loads a pkl file. + + For an example, see `save_pkl`. + + Arguments + --------- + file : str + Path to the input pkl file. + + Returns + ------- + The loaded object. + """ + + # Deals with the situation where two processes are trying + # to access the same label dictionary by creating a lock + count = 100 + while count > 0: + if os.path.isfile(file + ".lock"): + time.sleep(1) + count -= 1 + else: + break + + try: + open(file + ".lock", "w", encoding="utf-8").close() + with open(file, "rb") as f: + return pickle.load(f) + finally: + if os.path.isfile(file + ".lock"): + os.remove(file + ".lock") + + +def prepend_bos_token(label, bos_index): + """Create labels with token at the beginning. + + Arguments + --------- + label : torch.IntTensor + Containing the original labels. Must be of size: [batch_size, max_length]. + bos_index : int + The index for token. + + Returns + ------- + new_label : tensor + The new label with at the beginning. + + Example + ------- + >>> label = torch.LongTensor([[1, 0, 0], [2, 3, 0], [4, 5, 6]]) + >>> new_label = prepend_bos_token(label, bos_index=7) + >>> new_label + tensor([[7, 1, 0, 0], + [7, 2, 3, 0], + [7, 4, 5, 6]]) + """ + new_label = label.long().clone() + batch_size = label.shape[0] + + bos = new_label.new_zeros(batch_size, 1).fill_(bos_index) + new_label = torch.cat([bos, new_label], dim=1) + return new_label + + +def append_eos_token(label, length, eos_index): + """Create labels with token appended. + + Arguments + --------- + label : torch.IntTensor + Containing the original labels. Must be of size: [batch_size, max_length] + length : torch.LongTensor + Containing the original length of each label sequences. Must be 1D. + eos_index : int + The index for token. + + Returns + ------- + new_label : tensor + The new label with appended. + + Example + ------- + >>> label = torch.IntTensor([[1, 0, 0], [2, 3, 0], [4, 5, 6]]) + >>> length = torch.LongTensor([1, 2, 3]) + >>> new_label = append_eos_token(label, length, eos_index=7) + >>> new_label + tensor([[1, 7, 0, 0], + [2, 3, 7, 0], + [4, 5, 6, 7]], dtype=torch.int32) + """ + new_label = label.int().clone() + batch_size = label.shape[0] + + pad = new_label.new_zeros(batch_size, 1) + new_label = torch.cat([new_label, pad], dim=1) + new_label[torch.arange(batch_size), length.long()] = eos_index + return new_label + + +def merge_char(sequences, space="_"): + """Merge characters sequences into word sequences. + + Arguments + --------- + sequences : list + Each item contains a list, and this list contains a character sequence. + space : string + The token represents space. Default: _ + + Returns + ------- + The list contains word sequences for each sentence. + + Example + ------- + >>> sequences = [ + ... ["a", "b", "_", "c", "_", "d", "e"], + ... ["e", "f", "g", "_", "h", "i"], + ... ] + >>> results = merge_char(sequences) + >>> results + [['ab', 'c', 'de'], ['efg', 'hi']] + """ + results = [] + for seq in sequences: + words = "".join(seq).split(space) + results.append(words) + return results + + +def merge_csvs(data_folder, csv_lst, merged_csv): + """Merging several csv files into one file. + + Arguments + --------- + data_folder : string + The folder to store csv files to be merged and after merging. + csv_lst : list + Filenames of csv file to be merged. + merged_csv : string + The filename to write the merged csv file. + + Example + ------- + >>> tmpdir = getfixture("tmpdir") + >>> os.symlink( + ... os.path.realpath("tests/samples/annotation/speech.csv"), + ... tmpdir / "speech.csv", + ... ) + >>> merge_csvs(tmpdir, ["speech.csv", "speech.csv"], "test_csv_merge.csv") + """ + write_path = os.path.join(data_folder, merged_csv) + if os.path.isfile(write_path): + logger.info("Skipping merging. Completed in previous run.") + with open( + os.path.join(data_folder, csv_lst[0]), newline="", encoding="utf-8" + ) as f: + header = f.readline() + lines = [] + for csv_file in csv_lst: + with open( + os.path.join(data_folder, csv_file), newline="", encoding="utf-8" + ) as f: + for i, line in enumerate(f): + if i == 0: + # Checking header + if line != header: + raise ValueError( + f"Different header for {csv_lst[0]} and {csv}." + ) + continue + lines.append(line) + with open(write_path, "w", encoding="utf-8") as f: + f.write(header) + for line in lines: + f.write(line) + logger.info(f"{write_path} is created.") + + +def split_word(sequences, space="_"): + """Split word sequences into character sequences. + + Arguments + --------- + sequences: list + Each item contains a list, and this list contains a words sequence. + space: string + The token represents space. Default: _ + + Returns + ------- + The list contains word sequences for each sentence. + + Example + ------- + >>> sequences = [["ab", "c", "de"], ["efg", "hi"]] + >>> results = split_word(sequences) + >>> results + [['a', 'b', '_', 'c', '_', 'd', 'e'], ['e', 'f', 'g', '_', 'h', 'i']] + """ + results = [] + for seq in sequences: + chars = list(space.join(seq)) + results.append(chars) + return results + + +def clean_padding_(tensor, length, len_dim=1, mask_value=0.0): + """Sets the value of any padding on the specified tensor to mask_value. + + For instance, this can be used to zero out the outputs of an autoencoder + during training past the specified length. + + This is an in-place operation + + Arguments + --------- + tensor: torch.Tensor + a tensor of arbitrary dimension + length: torch.Tensor + a 1-D tensor of lengths + len_dim: int + the dimension representing the length + mask_value: mixed + the value to be assigned to padding positions + + Example + ------- + >>> import torch + >>> x = torch.arange(5).unsqueeze(0).repeat(3, 1) + >>> x = x + torch.arange(3).unsqueeze(-1) + >>> x + tensor([[0, 1, 2, 3, 4], + [1, 2, 3, 4, 5], + [2, 3, 4, 5, 6]]) + >>> length = torch.tensor([0.4, 1.0, 0.6]) + >>> clean_padding_(x, length=length, mask_value=10.0) + >>> x + tensor([[ 0, 1, 10, 10, 10], + [ 1, 2, 3, 4, 5], + [ 2, 3, 4, 10, 10]]) + >>> x = torch.arange(5)[None, :, None].repeat(3, 1, 2) + >>> x = x + torch.arange(3)[:, None, None] + >>> x = x * torch.arange(1, 3)[None, None, :] + >>> x = x.transpose(1, 2) + >>> x + tensor([[[ 0, 1, 2, 3, 4], + [ 0, 2, 4, 6, 8]], + + [[ 1, 2, 3, 4, 5], + [ 2, 4, 6, 8, 10]], + + [[ 2, 3, 4, 5, 6], + [ 4, 6, 8, 10, 12]]]) + >>> clean_padding_(x, length=length, mask_value=10.0, len_dim=2) + >>> x + tensor([[[ 0, 1, 10, 10, 10], + [ 0, 2, 10, 10, 10]], + + [[ 1, 2, 3, 4, 5], + [ 2, 4, 6, 8, 10]], + + [[ 2, 3, 4, 10, 10], + [ 4, 6, 8, 10, 10]]]) + """ + max_len = tensor.size(len_dim) + mask = length_to_mask(length * max_len, max_len).bool() + mask_unsq = mask[(...,) + (None,) * (tensor.dim() - 2)] + mask_t = mask_unsq.transpose(1, len_dim).expand_as(tensor) + tensor[~mask_t] = mask_value + + +def clean_padding(tensor, length, len_dim=1, mask_value=0.0): + """Sets the value of any padding on the specified tensor to mask_value. + + For instance, this can be used to zero out the outputs of an autoencoder + during training past the specified length. + + This version of the operation does not modify the original tensor + + Arguments + --------- + tensor: torch.Tensor + a tensor of arbitrary dimension + length: torch.Tensor + a 1-D tensor of lengths + len_dim: int + the dimension representing the length + mask_value: mixed + the value to be assigned to padding positions + + Returns + ------- + result: torch.Tensor + Tensor with updated padding. + + Example + ------- + >>> import torch + >>> x = torch.arange(5).unsqueeze(0).repeat(3, 1) + >>> x = x + torch.arange(3).unsqueeze(-1) + >>> x + tensor([[0, 1, 2, 3, 4], + [1, 2, 3, 4, 5], + [2, 3, 4, 5, 6]]) + >>> length = torch.tensor([0.4, 1.0, 0.6]) + >>> x_p = clean_padding(x, length=length, mask_value=10.0) + >>> x_p + tensor([[ 0, 1, 10, 10, 10], + [ 1, 2, 3, 4, 5], + [ 2, 3, 4, 10, 10]]) + >>> x = torch.arange(5)[None, :, None].repeat(3, 1, 2) + >>> x = x + torch.arange(3)[:, None, None] + >>> x = x * torch.arange(1, 3)[None, None, :] + >>> x = x.transpose(1, 2) + >>> x + tensor([[[ 0, 1, 2, 3, 4], + [ 0, 2, 4, 6, 8]], + + [[ 1, 2, 3, 4, 5], + [ 2, 4, 6, 8, 10]], + + [[ 2, 3, 4, 5, 6], + [ 4, 6, 8, 10, 12]]]) + >>> x_p = clean_padding(x, length=length, mask_value=10.0, len_dim=2) + >>> x_p + tensor([[[ 0, 1, 10, 10, 10], + [ 0, 2, 10, 10, 10]], + + [[ 1, 2, 3, 4, 5], + [ 2, 4, 6, 8, 10]], + + [[ 2, 3, 4, 10, 10], + [ 4, 6, 8, 10, 10]]]) + """ + + result = tensor.clone() + clean_padding_(result, length, len_dim, mask_value) + return result + + +def extract_concepts_values(sequences, keep_values, tag_in, tag_out, space): + """keep the semantic concepts and values for evaluation. + + Arguments + --------- + sequences: list + Each item contains a list, and this list contains a character sequence. + keep_values: bool + If True, keep the values. If not don't. + tag_in: char + Indicates the start of the concept. + tag_out: char + Indicates the end of the concept. + space: string + The token represents space. Default: _ + + Returns + ------- + The list contains concept and value sequences for each sentence. + + Example + ------- + >>> sequences = [ + ... [ + ... "", + ... "_", + ... "n", + ... "o", + ... "_", + ... ">", + ... "_", + ... "", + ... "_", + ... "L", + ... "e", + ... "_", + ... "M", + ... "a", + ... "n", + ... "s", + ... "_", + ... ">", + ... ], + ... ["", "_", "s", "i", "_", ">"], + ... ["v", "a", "_", "b", "e", "n", "e"], + ... ] + >>> results = extract_concepts_values(sequences, True, "<", ">", "_") + >>> results + [[' no', ' Le Mans'], [' si'], ['']] + """ + results = [] + for sequence in sequences: + # ['_no_>__Le_Mans_>'] + sequence = "".join(sequence) + # ['','no','>','','Le','Mans,'>'] + sequence = sequence.split(space) + processed_sequence = [] + value = [] # If previous sequence value never used because never had a tag_out + kept = "" # If previous sequence kept never used because never had a tag_out + concept_open = False + for word in sequence: + if re.match(tag_in, word): + # If not close tag but new tag open + if concept_open and keep_values: + if len(value) != 0: + kept += " " + " ".join(value) + concept_open = False + processed_sequence.append(kept) + kept = word # 1st loop: '' + value = [] # Concept's value + concept_open = True # Trying to catch the concept's value + # If we want the CER + if not keep_values: + processed_sequence.append(kept) # Add the kept concept + # If we have a tag_out, had a concept, and want the values for CVER + elif re.match(tag_out, word) and concept_open and keep_values: + # If we have a value + if len(value) != 0: + kept += " " + " ".join( + value + ) # 1st loop: '' + ' ' + 'no' + concept_open = False # Wait for a new tag_in to pursue + processed_sequence.append(kept) # Add the kept concept + value + elif concept_open: + value.append(word) # 1st loop: 'no' + # If not close tag but end sequence + if concept_open and keep_values: + if len(value) != 0: + kept += " " + " ".join(value) + concept_open = False + processed_sequence.append(kept) + if len(processed_sequence) == 0: + processed_sequence.append("") + results.append(processed_sequence) + return results diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataloader.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataloader.py new file mode 100644 index 000000000..fb0aaa485 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataloader.py @@ -0,0 +1,420 @@ +"""PyTorch compatible DataLoaders + +Essentially we extend PyTorch DataLoader by adding the ability to save the +data loading state, so that a checkpoint may be saved in the middle of an +epoch. + +Example +------- +>>> import torch +>>> from speechbrain.utils.checkpoints import Checkpointer +>>> # An example "dataset" and its loader +>>> dataset = torch.randn(10, 1) +>>> dataloader = SaveableDataLoader(dataset, num_workers=3) +>>> # Setup the checkpointer: +>>> tmpdir = getfixture("tmpdir") +>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader}) +>>> # Iterate: +>>> for i, data_point in enumerate(dataloader): +... # Here you would process the data: +... rainfall_amount_prediction = data_point * 4.0 +... # Now, imagine the experiment gets killed on the fifth batch: +... if i == 4: +... break +... # Luckily, you had just saved a checkpoint: +... if i == 3: +... _ = checkpointer.save_checkpoint(end_of_epoch=False) +>>> # So when you restart the experiment: +>>> new_dataloader = SaveableDataLoader(dataset, num_workers=3) +>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader}) +>>> _ = new_checkpointer.recover_if_possible() +>>> # The dataloader fast-forwards to the position where we left off: +>>> assert next(iter(new_dataloader)) == dataset[4] + +Authors: + * Aku Rouhe 2020 +""" + +import functools +import os +import warnings + +from torch.utils.data import DataLoader, DistributedSampler, IterableDataset +from torch.utils.data.dataloader import _BaseDataLoaderIter + +from speechbrain.dataio.batch import BatchsizeGuesser, PaddedBatch +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.dataio.sampler import ( + DistributedSamplerWrapper, + ReproducibleRandomSampler, +) +from speechbrain.utils.checkpoints import ( + mark_as_loader, + mark_as_saver, + register_checkpoint_hooks, +) +from speechbrain.utils.logger import get_logger + +# Optional support for webdataset +try: + import webdataset as wds + from importlib_metadata import version + + WDS_AVAILABLE = True + + # Use appropriate class based on webdataset version + if version("webdataset")[0:4] == "0.1.": + WDS_CLASS = wds.dataset.Composable + else: + WDS_CLASS = wds.DataPipeline +except ImportError: + WDS_AVAILABLE = False + +logger = get_logger(__name__) + + +def distributed_loader_specifics( + distributed_launch, rank, dataset, loader_kwargs +): + """Prepare loader_kwargs for DDP when necessary. + + Arguments + --------- + distributed_launch : bool + DDP flag + rank : int + node rank in DDP + dataset : Dataset + The dataset to make a DataLoader for. + loader_kwargs : dict + Keyword args to DataLoader, see PyTorch DataLoader for + options. + + Returns + ------- + loader_kwargs + augmented keyword args to DataLoader + """ + sampler = loader_kwargs.get("sampler", None) + shuffle = loader_kwargs.get("shuffle", False) + # Possibly make a DistributedSampler or a wrapper for some other sampler + if distributed_launch and not isinstance(dataset, IterableDataset): + drop_last = loader_kwargs.get("drop_last", False) + # num_replicas arg is equal to world_size + # and retrieved automatically within + # DistributedSampler obj. + if sampler is not None: + sampler = DistributedSamplerWrapper( + sampler, + rank=rank, + drop_last=drop_last, + shuffle=shuffle, + ) + + # with DistributedSamplerWrapper, one must disable shuffling for dataloader + loader_kwargs["shuffle"] = False + loader_kwargs["sampler"] = sampler + elif loader_kwargs.get("batch_sampler") is None: + # no sampler and batch-sampler + sampler = DistributedSampler( + dataset, + rank=rank, + drop_last=drop_last, + ) + + # with DistributedSamplerWrapper, one must disable shuffling for dataloader + loader_kwargs["shuffle"] = False + loader_kwargs["sampler"] = sampler + else: # batch_sampler was specified + sampler = DistributedSamplerWrapper( + loader_kwargs.get("batch_sampler", None), + rank=rank, + ) + loader_kwargs["batch_sampler"] = sampler + elif distributed_launch and isinstance(dataset, IterableDataset): + logger.warning( + "Cannot automatically solve distributed sampling " + "for IterableDataset." + ) + return loader_kwargs + + +def make_dataloader(dataset, looped_nominal_epoch=None, **loader_kwargs): + """Makes a basic DataLoader with SpeechBrain defaults. + + For DynamicItemDatasets (which return dicts), use + PaddedBatch as the default collate_fn. + + Shuffling gets implemented by ReproducibleRandomSampler. + + If the Dataset is not an IterableDataset, the DataLoader + is a SaveableDataLoader. + + If the Dataset is a webdataset.dataset.Composable, set default + batch_size = None. + + Can also loop over the underlying dataloader continuously, + and stop iterations at nominal epoch lengths. + + Arguments + --------- + dataset : Dataset + The dataset to make a DataLoader for. + looped_nominal_epoch : None, int + If an integer is given, loop the underlying DataLoader infinitely and + set a nominal epoch length in batches (or whatever the DataLoader + yields). + **loader_kwargs : dict + Keyword args to DataLoader, see PyTorch DataLoader for + options. + + Returns + ------- + DataLoader + If looped_nominal_epoch is None + LoopedLoader + If looped_nominal_epoch is not None + """ + # PaddedBatch as default collation for DynamicItemDataset + if "collate_fn" not in loader_kwargs and isinstance( + dataset, DynamicItemDataset + ): + loader_kwargs["collate_fn"] = PaddedBatch + # Reproducible random sampling + if loader_kwargs.get("shuffle", False): + if loader_kwargs.get("sampler") is not None: + raise ValueError( + "Cannot specify both shuffle=True and a " + "sampler in loader_kwargs" + ) + seed = int(os.environ.get("SB_GLOBAL_SEED", 563375142)) + sampler = ReproducibleRandomSampler(dataset, seed=seed) + loader_kwargs["sampler"] = sampler + # Should delete shuffle because you can't set both Sampler and + # shuffle + # NOTE: the dict of loader options may get used elsewhere! + # However, this del doesn't touch those because loader_kwargs comes + # from a **kwargs dict. + del loader_kwargs["shuffle"] + # With WDS it is recommended to do batching in the dataset itself, + # which requires batch_size = None in the DataLoader + if ( + WDS_AVAILABLE + and isinstance(dataset, WDS_CLASS) + and "batch_size" not in loader_kwargs + ): + loader_kwargs["batch_size"] = None + # Create the loader + if isinstance(dataset, IterableDataset): + dataloader = DataLoader(dataset, **loader_kwargs) + else: + dataloader = SaveableDataLoader(dataset, **loader_kwargs) + if looped_nominal_epoch is not None: + dataloader = LoopedLoader(dataloader, looped_nominal_epoch) + return dataloader + + +# We essentially want to make the DataLoader iterators able to skip ahead +# after checkpoint recovery +# This should be handled by the DataLoader iterators' base class. +# To make the implementation here a little more maintainable +# we decide to patch some PyTorch functionality + + +def __new_init(self, loader, *args, **kwargs): + self.__old_init__(loader, *args, **kwargs) + if ( + hasattr(loader, "_speechbrain_recovery_skip_to") + and loader._speechbrain_recovery_skip_to is not None + ): + # Fast forward the sampler iterator since we have recovered: + for i in range(loader._speechbrain_recovery_skip_to): + try: + next(self._sampler_iter) + except StopIteration: + MSG = "Tried to fast-forward Sampler after checkpoint " + f"recovery by {loader._speechbrain_recovery_skip_to} " + "indices, but now Sampler raised StopIteration after " + f"{i} indices. Ignoring this mismatch." + warnings.warn(MSG) + break + self._num_yielded = i + 1 + # Mark recovery as done: + loader._speechbrain_recovery_skip_to = None + + +def __new_reset(self, loader, first_iter=False, *args, **kwargs): + # On the first iteration, these have already normally been set by the init anyway. + # And we don't want to overwrite them if we've recovered + if not first_iter: + self._sampler_iter = iter(self._index_sampler) + self._num_yielded = 0 + self._IterableDataset_len_called = loader._IterableDataset_len_called + + +# functools.update_wrapper is meant for decorators, but it should basically +# preserve what we want: +functools.update_wrapper(__new_init, _BaseDataLoaderIter.__init__) +_BaseDataLoaderIter.__old_init__ = _BaseDataLoaderIter.__init__ +_BaseDataLoaderIter.__init__ = __new_init +if hasattr(_BaseDataLoaderIter, "_reset"): + _BaseDataLoaderIter._reset = __new_reset + + +@register_checkpoint_hooks +class SaveableDataLoader(DataLoader): + """A saveable version of the PyTorch DataLoader. + + See `torch.utils.data.DataLoader` for usage. This class should work exactly + like the PyTorch basic DataLoader, but this can be checkpointed with + SpeechBrain's Checkpointer. + + Note + ---- + 1. The saveability is implemented via some unfortunately slightly magical + means. + 2. The data loader cannot recover after entering __iter__. Normally this is + not a problem, as recovery should happen before training begins. However, + just before evaluation, it is also typical to recover the checkpoint at + which performance was the best. Thus, if a checkpoint is loaded after + entering __iter__, we just assume it is for this reason. A warning is + logged, but that is all. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if isinstance(self.dataset, IterableDataset): + logger.warning( + "SaveableDataLoader cannot save the position in an " + "IterableDataset. Save the position on the dataset itself." + ) + self._speechbrain_recovery_skip_to = None + self._speechbrain_iterator = None + + def __iter__(self): + iterator = super().__iter__() + # Keep a reference to the iterator, + # to be able to access the iterator._num_yielded value. + # Keep a full reference (keeping the iterator alive) + # rather than e.g. a weakref, as we may want to save a checkpoint + # after the iterator has been exhausted, but before the full epoch has + # ended (e.g. validation is still running) + self._speechbrain_iterator = iterator + return iterator + + @mark_as_saver + def _speechbrain_save(self, path): + if isinstance(self.dataset, IterableDataset): + logger.warning( + "Warning again: a checkpoint was requested on " + "SaveableDataLoader, but the dataset is an IterableDataset. " + "Cannot save the position in an IterableDataset. Not raising " + "an error; assuming that you know what you're doing." + ) + if self._speechbrain_iterator is None: + to_save = None + else: + to_save = self._speechbrain_iterator._num_yielded + with open(path, "w", encoding="utf-8") as fo: + fo.write(str(to_save)) + + @mark_as_loader + def _speechbrain_load(self, path, end_of_epoch): + if self._speechbrain_iterator is not None: + logger.debug( + "SaveableDataLoader was requested to load a " + "checkpoint, but the DataLoader has already been " + "iterated. The DataLoader file will be ignored. " + "This is normal in evaluation, when a checkpoint is " + "loaded just to retrieve the best model." + ) + return + if end_of_epoch: + # Don't load at end of epoch, as we actually want to start a fresh + # epoch iteration next. + return + with open(path, encoding="utf-8") as fi: + saved = fi.read() + if saved == str(None): + # Saved at a point where e.g. an iterator did not yet exist. + return + else: + self._speechbrain_recovery_skip_to = int(saved) + + +@register_checkpoint_hooks +class LoopedLoader: + """Loops an underlying iterable indefinitely, with nominal epoch lengths + + This is useful for working with IterableDatasets, and particularly + webdataset-style loading. We recommend using ``.repeat()`` on the + webdataset IterableDataset instance, so that the underlying dataloader + naturally continues for ever. + + Arguments + --------- + loader : iterable + A DataLoader or other iterable that is looped repeatedly. + epoch_length : int + The length of the nominal epoch. After this many steps, raises + StopIteration + batchsize_fn : callable + Function for determining batch size, default ``BatchsizeGuesser`` + """ + + def __init__(self, loader, epoch_length, batchsize_fn=None): + self.loader = loader + self.iterator = None + self.epoch_length = epoch_length + self.step = 0 # Step in epoch + self.total_steps = 0 # Total steps ever + self.total_samples = 0 # Total samples seen on this process + if batchsize_fn is None: + self.batchsize_fn = BatchsizeGuesser() + + def __iter__(self): + if self.iterator is None: + self.iterator = iter(self.loader) + return self + + def __next__(self): + if self.step < self.epoch_length: + self.step += 1 + self.total_steps += 1 + try: + batch = next(self.iterator) + except StopIteration: + self.iterator = iter(self.loader) + batch = next(self.iterator) + self.total_samples += self.batchsize_fn(batch) + return batch + else: + self.step = 0 + raise StopIteration + + def __len__(self): + return self.epoch_length + + @mark_as_saver + def save(self, path): + """Saves the needed information.""" + with open(path, "w", encoding="utf-8") as fo: + print(self.step, file=fo) + print(self.total_steps, file=fo) + print(self.total_samples, file=fo) + + @mark_as_loader + def load(self, path, end_of_epoch=True): + """Loads the needed information.""" + with open(path, encoding="utf-8") as fi: + self.step = int(fi.readline().strip()) + self.total_steps = int(fi.readline().strip()) + self.total_samples = int(fi.readline().strip()) + if not end_of_epoch and self.step == 0 and self.total_steps > 0: + # Step has been set to 0 at the end of iteration, + # so return it to epoch_length, so that first iteration + # of this will immediately raise StopIteration. + # Basically, this can happen when e.g. the main training + # loop has already finished but there is a checkpoint in the + # middle of validation. + self.step = self.epoch_length diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataset.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataset.py new file mode 100644 index 000000000..1ec508385 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/dataset.py @@ -0,0 +1,546 @@ +"""Dataset examples for loading individual data points + +Authors + * Aku Rouhe 2020 + * Samuele Cornell 2020 +""" + +import contextlib +import copy +import math +from types import MethodType + +import tqdm +from torch.utils.data import Dataset + +from speechbrain.dataio.dataio import load_data_csv, load_data_json +from speechbrain.utils.data_pipeline import DataPipeline +from speechbrain.utils.data_utils import batch_shuffle +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +class DynamicItemDataset(Dataset): + """Dataset that reads, wrangles, and produces dicts. + + Each data point dict provides some items (by key), for example, a path to a + wavefile with the key "wav_file". When a data point is fetched from this + Dataset, more items are produced dynamically, based on pre-existing items + and other dynamic created items. For example, a dynamic item could take the + wavfile path and load the audio from the disk. + + The dynamic items can depend on other dynamic items: a suitable evaluation + order is used automatically, as long as there are no circular dependencies. + + A specified list of keys is collected in the output dict. These can be items + in the original data or dynamic items. If some dynamic items are not + requested, nor depended on by other requested items, they won't be computed. + So for example if a user simply wants to iterate over the text, the + time-consuming audio loading can be skipped. + + About the format: + Takes a dict of dicts as the collection of data points to read/wrangle. + The top level keys are data point IDs. + Each data point (example) dict should have the same keys, corresponding to + different items in that data point. + + Altogether the data collection could look like this: + + >>> data = { + ... "spk1utt1": { + ... "wav_file": "/path/to/spk1utt1.wav", + ... "text": "hello world", + ... "speaker": "spk1", + ... }, + ... "spk1utt2": { + ... "wav_file": "/path/to/spk1utt2.wav", + ... "text": "how are you world", + ... "speaker": "spk1", + ... }, + ... } + + NOTE + ---- + The top-level key, the data point id, is implicitly added as an item + in the data point, with the key "id" + + Each dynamic item is configured by three things: a key, a func, and a list + of argkeys. The key should be unique among all the items (dynamic or not) in + each data point. The func is any callable, and it returns the dynamic item's + value. The callable is called with the values of other items as specified + by the argkeys list (as positional args, passed in the order specified by + argkeys). + + The dynamic_items configuration could look like this: + + >>> import torch + >>> dynamic_items = [ + ... { + ... "func": lambda l: torch.Tensor(l), + ... "takes": ["wav_loaded"], + ... "provides": "wav", + ... }, + ... { + ... "func": lambda path: [ + ... ord(c) / 100 for c in path + ... ], # Fake "loading" + ... "takes": ["wav_file"], + ... "provides": "wav_loaded", + ... }, + ... { + ... "func": lambda t: t.split(), + ... "takes": ["text"], + ... "provides": "words", + ... }, + ... ] + + With these, different views of the data can be loaded: + + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> from speechbrain.dataio.batch import PaddedBatch + >>> dataset = DynamicItemDataset(data, dynamic_items) + >>> dataloader = SaveableDataLoader( + ... dataset, collate_fn=PaddedBatch, batch_size=2 + ... ) + >>> # First, create encoding for words: + >>> dataset.set_output_keys(["words"]) + >>> encoding = {} + >>> next_id = 1 + >>> for batch in dataloader: + ... for sent in batch.words: + ... for word in sent: + ... if word not in encoding: + ... encoding[word] = next_id + ... next_id += 1 + >>> # Next, add an encoded words_tensor dynamic item: + >>> dataset.add_dynamic_item( + ... func=lambda ws: torch.tensor( + ... [encoding[w] for w in ws], dtype=torch.long + ... ), + ... takes=["words"], + ... provides="words_encoded", + ... ) + >>> # Now we can get word and audio tensors: + >>> dataset.set_output_keys(["id", "wav", "words_encoded"]) + >>> batch = next(iter(dataloader)) + >>> batch.id + ['spk1utt1', 'spk1utt2'] + >>> batch.wav # +ELLIPSIS + PaddedData(data=tensor([[0.4700, 1.1200, ... + >>> batch.words_encoded + PaddedData(data=tensor([[1, 2, 0, 0], + [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000])) + + Output keys can also be a map: + + >>> dataset.set_output_keys( + ... {"id": "id", "signal": "wav", "words": "words_encoded"} + ... ) + >>> batch = next(iter(dataloader)) + >>> batch.words + PaddedData(data=tensor([[1, 2, 0, 0], + [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000])) + + + Arguments + --------- + data : dict + Dictionary containing single data points (e.g. utterances). + dynamic_items : list, optional + Configuration for the dynamic items produced when fetching an example. + List of DynamicItems or dicts with the format:: + func: # To be called + takes: # key or list of keys of args this takes + provides: key # key or list of keys that this provides + output_keys : dict, list, optional + List of keys (either directly available in data or dynamic items) + to include in the output dict when data points are fetched. + + If a dict is given; it is used to map internal keys to output keys. + From the output_keys dict key:value pairs the key appears outside, + and value is the internal key. + """ + + def __init__(self, data, dynamic_items=None, output_keys=None): + if dynamic_items is None: + dynamic_items = [] + if output_keys is None: + output_keys = [] + self.data = data + self.data_ids = list(self.data.keys()) + static_keys = list(self.data[self.data_ids[0]].keys()) + if "id" in static_keys: + raise ValueError("The key 'id' is reserved for the data point id.") + else: + static_keys.append("id") + self.pipeline = DataPipeline(static_keys, dynamic_items) + self.set_output_keys(output_keys) + + def __len__(self): + return len(self.data_ids) + + def __getitem__(self, index): + data_id = self.data_ids[index] + data_point = self.data[data_id] + return self.pipeline.compute_outputs({"id": data_id, **data_point}) + + def iterate_once(self, output_keys=None, progressbar=True): + """Iterates dataset once -- mainly used to warm up cache. + + Arguments + --------- + output_keys : Optional[list[str]] + List of keys to use for the iteration, potentially useful for + speeding up iterations when warming the cache is only needed on + a subset of the slow keys and other slow keys should be ignored. + progressbar : bool + Whether to add a tqdm progressbar for monitoring iteration time. + """ + + # If output_keys is None, just use current output mapping + output_keys = output_keys or self.pipeline.output_mapping + + # Iterate data but do nothing (e.g. to warm cache) + with self.output_keys_as(output_keys): + for item in tqdm.tqdm(self, disable=not progressbar): + pass + + def add_dynamic_item(self, func, takes=None, provides=None): + """Makes a new dynamic item available on the dataset. + + Two calling conventions. For DynamicItem objects, just use: + add_dynamic_item(dynamic_item). + But otherwise, should use: + add_dynamic_item(func, takes, provides). + + See `speechbrain.utils.data_pipeline`. + + Arguments + --------- + func : callable, DynamicItem + If a DynamicItem is given, adds that directly. Otherwise a + DynamicItem is created, and this specifies the callable to use. If + a generator function is given, then create a GeneratorDynamicItem. + Otherwise creates a normal DynamicItem. + takes : list, str + List of keys. When func is called, each key is resolved to + either an entry in the data or the output of another dynamic_item. + The func is then called with these as positional arguments, + in the same order as specified here. + A single arg can be given directly. + provides : str + Unique key or keys that this provides. + """ + self.pipeline.add_dynamic_item(func, takes, provides) + + def set_output_keys(self, keys): + """Use this to change the output keys. + + These are the keys that are actually evaluated when a data point + is fetched from the dataset. + + Arguments + --------- + keys : dict, list + List of keys (str) to produce in output. + + If a dict is given; it is used to map internal keys to output keys. + From the output_keys dict key:value pairs the key appears outside, + and value is the internal key. + """ + self.pipeline.set_output_keys(keys) + + @contextlib.contextmanager + def output_keys_as(self, keys): + """Context manager to temporarily set output keys. + + Arguments + --------- + keys : list + A set of output keys to use in the context. + + Example + ------- + >>> dataset = DynamicItemDataset( + ... {"a": {"x": 1, "y": 2}, "b": {"x": 3, "y": 4}}, + ... output_keys=["x"], + ... ) + >>> with dataset.output_keys_as(["y"]): + ... print(dataset[0]) + {'y': 2} + >>> print(dataset[0]) + {'x': 1} + + NOTE + ---- + Not thread-safe. While in this context manager, the output keys + are affected for any call. + + Yields + ------ + self + """ + saved_output = self.pipeline.output_mapping + self.pipeline.set_output_keys(keys) + yield self + self.pipeline.set_output_keys(saved_output) + + def filtered_sorted( + self, + key_min_value={}, + key_max_value={}, + key_test={}, + sort_key=None, + reverse=False, + select_n=None, + ): + """Get a filtered and/or sorted version of this, shares static data. + + The reason to implement these operations in the same method is that + computing some dynamic items may be expensive, and this way the + filtering and sorting steps don't need to compute the dynamic items + twice. + + Arguments + --------- + key_min_value : dict + Map from key (in data or in dynamic items) to limit, will only keep + data_point if data_point[key] >= limit + key_max_value : dict + Map from key (in data or in dynamic items) to limit, will only keep + data_point if data_point[key] <= limit + key_test : dict + Map from key (in data or in dynamic items) to func, will only keep + data_point if bool(func(data_point[key])) == True + sort_key : None, str + If not None, sort by data_point[sort_key]. Default is ascending + order. + reverse : bool + If True, sort in descending order. + select_n : None, int + If not None, only keep (at most) the first n filtered data_points. + The possible sorting is applied, but only on the first n data + points found. Meant for debugging. + + Returns + ------- + FilteredSortedDynamicItemDataset + Shares the static data, but has its own output keys and + dynamic items (initially deep copied from this, so they have the + same dynamic items available) + + NOTE + ---- + Temporarily changes the output keys! + """ + filtered_sorted_ids = self._filtered_sorted_ids( + key_min_value, key_max_value, key_test, sort_key, reverse, select_n + ) + return FilteredSortedDynamicItemDataset( + self, filtered_sorted_ids + ) # NOTE: defined below + + def _filtered_sorted_ids( + self, + key_min_value={}, + key_max_value={}, + key_test={}, + sort_key=None, + reverse=False, + select_n=None, + ): + """Returns a list of data ids, fulfilling the sorting and filtering.""" + + def combined_filter(computed): + """Applies filter.""" + for key, limit in key_min_value.items(): + # NOTE: docstring promises >= so using that. + # Mathematically could also use < for nicer syntax, but + # maybe with some super special weird edge case some one can + # depend on the >= operator + if computed[key] >= limit: + continue + return False + for key, limit in key_max_value.items(): + if computed[key] <= limit: + continue + return False + for key, func in key_test.items(): + if bool(func(computed[key])): + continue + return False + return True + + temp_keys = ( + set(key_min_value.keys()) + | set(key_max_value.keys()) + | set(key_test.keys()) + | set([] if sort_key is None else [sort_key]) + ) + filtered_ids = [] + with self.output_keys_as(temp_keys): + for i, data_id in enumerate(self.data_ids): + if select_n is not None and len(filtered_ids) == select_n: + break + data_point = self.data[data_id] + data_point["id"] = data_id + computed = self.pipeline.compute_outputs(data_point) + if combined_filter(computed): + if sort_key is not None: + # Add (main sorting index, current index, data_id) + # So that we maintain current sorting and don't compare + # data_id values ever. + filtered_ids.append((computed[sort_key], i, data_id)) + else: + filtered_ids.append(data_id) + if sort_key is not None: + filtered_sorted_ids = [ + tup[2] for tup in sorted(filtered_ids, reverse=reverse) + ] + else: + filtered_sorted_ids = filtered_ids + return filtered_sorted_ids + + def overfit_test(self, sample_count, total_count): + """Creates a subset of this dataset for an overfitting + test - repeating sample_count samples to create a repeating + dataset with a total of epoch_data_count samples + + Arguments + --------- + sample_count: int + the number of samples to select + total_count: int + the total data count + + Returns + ------- + dataset: FilteredSortedDynamicItemDataset + a dataset with a repeated subset + """ + num_repetitions = math.ceil(total_count / sample_count) + overfit_samples = self.data_ids[:sample_count] * num_repetitions + overfit_samples = overfit_samples[:total_count] + return FilteredSortedDynamicItemDataset(self, overfit_samples) + + def batch_shuffle(self, batch_size): + """Shuffles batches within a dataset. This is particularly + useful in combination with length sorting - to ensure + that the length variation within a batch is not very high, + but the batches themselves remain randomized + + Arguments + --------- + batch_size: int + the batch size + + Returns + ------- + dataset: FilteredSortedDynamicItemDataset + a shuffled dataset + """ + data_ids = batch_shuffle(self.data_ids, batch_size) + return FilteredSortedDynamicItemDataset(self, data_ids) + + @classmethod + def from_json( + cls, json_path, replacements={}, dynamic_items=[], output_keys=[] + ): + """Load a data prep JSON file and create a Dataset based on it.""" + data = load_data_json(json_path, replacements) + return cls(data, dynamic_items, output_keys) + + @classmethod + def from_csv( + cls, csv_path, replacements={}, dynamic_items=[], output_keys=[] + ): + """Load a data prep CSV file and create a Dataset based on it.""" + data = load_data_csv(csv_path, replacements) + return cls(data, dynamic_items, output_keys) + + @classmethod + def from_arrow_dataset( + cls, dataset, replacements={}, dynamic_items=[], output_keys=[] + ): + """Loading a prepared huggingface dataset""" + + # define an unbound method to generate pseudo keys + def keys(self): + "Returns the keys." + return [i for i in range(dataset.__len__())] + + # bind this method to arrow dataset + dataset.keys = MethodType(keys, dataset) + return cls(dataset, dynamic_items, output_keys) + + +class FilteredSortedDynamicItemDataset(DynamicItemDataset): + """Possibly filtered, possibly sorted DynamicItemDataset. + + Shares the static data (reference). + Has its own dynamic_items and output_keys (deepcopy). + """ + + def __init__(self, from_dataset, data_ids): + self.data = from_dataset.data + self.data_ids = data_ids + self.pipeline = copy.deepcopy(from_dataset.pipeline) + + @classmethod + def from_json( + cls, json_path, replacements={}, dynamic_items=None, output_keys=None + ): + raise TypeError("Cannot create SubsetDynamicItemDataset directly!") + + @classmethod + def from_csv( + cls, csv_path, replacements={}, dynamic_items=None, output_keys=None + ): + raise TypeError("Cannot create SubsetDynamicItemDataset directly!") + + +def add_dynamic_item(datasets, func, takes=None, provides=None): + """Helper for adding the same item to multiple datasets.""" + for dataset in datasets: + dataset.add_dynamic_item(func, takes, provides) + + +def set_output_keys(datasets, output_keys): + """Helper for setting the same item to multiple datasets.""" + for dataset in datasets: + dataset.set_output_keys(output_keys) + + +def apply_overfit_test( + overfit_test, + overfit_test_sample_count, + overfit_test_epoch_data_count, + dataset, +): + """Applies the overfit test to the specified dataset, + as configured in the hyperparameters file + + Arguments + --------- + + overfit_test: bool + when True the overfitting test is performed + overfit_test_sample_count: int + number of samples for the overfitting test + overfit_test_epoch_data_count: int + number of epochs for the overfitting test + + dataset: DynamicItemDataset + the dataset + + Returns + ------- + dataset: DynamicItemDataset + the dataset, with the overfit test apply + """ + if overfit_test: + sample_count = overfit_test_sample_count + epoch_data_count = overfit_test_epoch_data_count + dataset = dataset.overfit_test(sample_count, epoch_data_count) + return dataset diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/encoder.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/encoder.py new file mode 100644 index 000000000..286e70f4a --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/encoder.py @@ -0,0 +1,1216 @@ +"""Encoding categorical data as integers + +Authors + * Samuele Cornell 2020 + * Aku Rouhe 2020 +""" + +import ast +import collections +import itertools + +import torch + +import speechbrain as sb +from speechbrain.utils.checkpoints import ( + mark_as_loader, + mark_as_saver, + register_checkpoint_hooks, +) +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + +# NOTE: Changing these does NOT change the defaults in the classes. +# Consider these read-only. +DEFAULT_UNK = "" +DEFAULT_BOS = "" +DEFAULT_EOS = "" +DEFAULT_BLANK = "" + + +@register_checkpoint_hooks +class CategoricalEncoder: + """Encode labels of a discrete set. + + Used for encoding, e.g., speaker identities in speaker recognition. + Given a collection of hashables (e.g a strings) it encodes + every unique item to an integer value: ["spk0", "spk1"] --> [0, 1] + Internally the correspondence between each label to its index is handled by + two dictionaries: lab2ind and ind2lab. + + The label integer encoding can be generated automatically from a SpeechBrain + DynamicItemDataset by specifying the desired entry (e.g., spkid) in the annotation + and calling update_from_didataset method: + + >>> from speechbrain.dataio.encoder import CategoricalEncoder + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> dataset = { + ... "ex_{}".format(x): {"spkid": "spk{}".format(x)} for x in range(20) + ... } + >>> dataset = DynamicItemDataset(dataset) + >>> encoder = CategoricalEncoder() + >>> encoder.update_from_didataset(dataset, "spkid") + >>> assert len(encoder) == len( + ... dataset + ... ) # different speaker for each utterance + + However can also be updated from an iterable: + + >>> from speechbrain.dataio.encoder import CategoricalEncoder + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> dataset = ["spk{}".format(x) for x in range(20)] + >>> encoder = CategoricalEncoder() + >>> encoder.update_from_iterable(dataset) + >>> assert len(encoder) == len(dataset) + + Note + ---- + In both methods it can be specified it the single element in the iterable + or in the dataset should be treated as a sequence or not (default False). + If it is a sequence each element in the sequence will be encoded. + + + >>> from speechbrain.dataio.encoder import CategoricalEncoder + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> dataset = [[x + 1, x + 2] for x in range(20)] + >>> encoder = CategoricalEncoder() + >>> encoder.ignore_len() + >>> encoder.update_from_iterable(dataset, sequence_input=True) + >>> assert len(encoder) == 21 # there are only 21 unique elements 1-21 + + This class offers 4 different methods to explicitly add a label in the internal + dicts: add_label, ensure_label, insert_label, enforce_label. + add_label and insert_label will raise an error if it is already present in the + internal dicts. insert_label, enforce_label allow also to specify the integer value + to which the desired label is encoded. + + Encoding can be performed using 4 different methods: + encode_label, encode_sequence, encode_label_torch and encode_sequence_torch. + encode_label operate on single labels and simply returns the corresponding + integer encoding: + + >>> from speechbrain.dataio.encoder import CategoricalEncoder + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> dataset = ["spk{}".format(x) for x in range(20)] + >>> encoder.update_from_iterable(dataset) + >>> + 22 + >>> + encode_sequence on sequences of labels: + >>> encoder.encode_sequence(["spk1", "spk19"]) + [22, 40] + >>> + encode_label_torch and encode_sequence_torch return torch tensors + >>> encoder.encode_sequence_torch(["spk1", "spk19"]) + tensor([22, 40]) + >>> + Decoding can be performed using decode_torch and decode_ndim methods. + >>> encoded = encoder.encode_sequence_torch(["spk1", "spk19"]) + >>> encoder.decode_torch(encoded) + ['spk1', 'spk19'] + >>> + decode_ndim is used for multidimensional list or pytorch tensors + >>> encoded = encoded.unsqueeze(0).repeat(3, 1) + >>> encoder.decode_torch(encoded) + [['spk1', 'spk19'], ['spk1', 'spk19'], ['spk1', 'spk19']] + >>> + + In some applications, it can happen that during testing a label which has not + been encountered during training is encountered. To handle this out-of-vocabulary + problem add_unk can be used. Every out-of-vocab label is mapped to this special + label and its corresponding integer encoding. + + >>> import torch + >>> try: + ... encoder.encode_label("spk42") + ... except KeyError: + ... print("spk42 is not in the encoder this raises an error!") + spk42 is not in the encoder this raises an error! + >>> encoder.add_unk() + 41 + >>> encoder.encode_label("spk42") + 41 + >>> + returns the encoding + + This class offers also methods to save and load the internal mappings between + labels and tokens using: save and load methods as well as load_or_create. + """ + + VALUE_SEPARATOR = " => " + EXTRAS_SEPARATOR = "================\n" + + def __init__(self, starting_index=0, **special_labels): + self.lab2ind = {} + self.ind2lab = {} + self.starting_index = starting_index + # NOTE: unk_label is not necessarily set at all! + # This is because None is a suitable value for unk. + # So the test is: hasattr(self, "unk_label") + # rather than self.unk_label is not None + self.handle_special_labels(special_labels) + + def handle_special_labels(self, special_labels): + """Handles special labels such as unk_label.""" + if "unk_label" in special_labels: + self.add_unk(special_labels["unk_label"]) + + def __len__(self): + return len(self.lab2ind) + + @classmethod + def from_saved(cls, path): + """Recreate a previously saved encoder directly""" + obj = cls() + obj.load(path) + return obj + + def update_from_iterable(self, iterable, sequence_input=False): + """Update from iterator + + Arguments + --------- + iterable : iterable + Input sequence on which to operate. + sequence_input : bool + Whether iterable yields sequences of labels or individual labels + directly. (default False) + """ + if sequence_input: + label_iterator = itertools.chain.from_iterable(iterable) + else: + label_iterator = iter(iterable) + for label in label_iterator: + self.ensure_label(label) + + def update_from_didataset( + self, didataset, output_key, sequence_input=False + ): + """Update from DynamicItemDataset. + + Arguments + --------- + didataset : DynamicItemDataset + Dataset on which to operate. + output_key : str + Key in the dataset (in data or a dynamic item) to encode. + sequence_input : bool + Whether the data yielded with the specified key consists of + sequences of labels or individual labels directly. + """ + with didataset.output_keys_as([output_key]): + self.update_from_iterable( + (data_point[output_key] for data_point in didataset), + sequence_input=sequence_input, + ) + + def limited_labelset_from_iterable( + self, iterable, sequence_input=False, n_most_common=None, min_count=1 + ): + """Produce label mapping from iterable based on label counts + + Used to limit label set size. + + Arguments + --------- + iterable : iterable + Input sequence on which to operate. + sequence_input : bool + Whether iterable yields sequences of labels or individual labels + directly. False by default. + n_most_common : int, None + Take at most this many labels as the label set, keeping the most + common ones. If None (as by default), take all. + min_count : int + Don't take labels if they appear less than this many times. + + Returns + ------- + collections.Counter + The counts of the different labels (unfiltered). + """ + if self.lab2ind: + clsname = self.__class__.__name__ + logger.info( + f"Limited_labelset_from_iterable called, " + f"but {clsname} is not empty. " + "The new labels will be added, i.e. won't overwrite. " + "This is normal if there is e.g. an unk label already." + ) + if sequence_input: + label_iterator = itertools.chain.from_iterable(iterable) + else: + label_iterator = iter(iterable) + counts = collections.Counter(label_iterator) + for label, count in counts.most_common(n_most_common): + if count < min_count: + # .most_common() produces counts in descending order, + # so no more labels can be found + break + self.add_label(label) + return counts + + def load_or_create( + self, + path, + from_iterables=[], + from_didatasets=[], + sequence_input=False, + output_key=None, + special_labels={}, + ): + """Convenient syntax for creating the encoder conditionally + + This pattern would be repeated in so many experiments that + we decided to add a convenient shortcut for it here. The + current version is multi-gpu (DDP) safe. + """ + try: + if sb.utils.distributed.if_main_process(): + if not self.load_if_possible(path): + for iterable in from_iterables: + self.update_from_iterable(iterable, sequence_input) + for didataset in from_didatasets: + if output_key is None: + raise ValueError( + "Provide an output_key for DynamicItemDataset" + ) + self.update_from_didataset( + didataset, output_key, sequence_input + ) + self.handle_special_labels(special_labels) + self.save(path) + finally: + sb.utils.distributed.ddp_barrier() + self.load(path) + + def add_label(self, label): + """Add new label to the encoder, at the next free position. + + Arguments + --------- + label : hashable + Most often labels are str, but anything that can act as dict key is + supported. Note that default save/load only supports Python + literals. + + Returns + ------- + int + The index that was used to encode this label. + """ + if label in self.lab2ind: + clsname = self.__class__.__name__ + raise KeyError(f"Label already present in {clsname}") + index = self._next_index() + self.lab2ind[label] = index + self.ind2lab[index] = label + return index + + def ensure_label(self, label): + """Add a label if it is not already present. + + Arguments + --------- + label : hashable + Most often labels are str, but anything that can act as dict key is + supported. Note that default save/load only supports Python + literals. + + Returns + ------- + int + The index that was used to encode this label. + """ + if label in self.lab2ind: + return self.lab2ind[label] + else: + return self.add_label(label) + + def insert_label(self, label, index): + """Add a new label, forcing its index to a specific value. + + If a label already has the specified index, it is moved to the end + of the mapping. + + Arguments + --------- + label : hashable + Most often labels are str, but anything that can act as dict key is + supported. Note that default save/load only supports Python + literals. + index : int + The specific index to use. + """ + if label in self.lab2ind: + clsname = self.__class__.__name__ + raise KeyError(f"Label already present in {clsname}") + else: + self.enforce_label(label, index) + + def enforce_label(self, label, index): + """Make sure label is present and encoded to a particular index. + + If the label is present but encoded to some other index, it is + moved to the given index. + + If there is already another label at the + given index, that label is moved to the next free position. + """ + index = int(index) + if label in self.lab2ind: + if index == self.lab2ind[label]: + return + else: + # Delete old index mapping. Everything else gets overwritten. + del self.ind2lab[self.lab2ind[label]] + # Move other label out of the way: + if index in self.ind2lab: + saved_label = self.ind2lab[index] + moving_other = True + else: + moving_other = False + # Ready to push the new index. + self.lab2ind[label] = index + self.ind2lab[index] = label + # And finally put the moved index in new spot. + if moving_other: + logger.info( + f"Moving label {repr(saved_label)} from index " + f"{index}, because {repr(label)} was put at its place." + ) + new_index = self._next_index() + self.lab2ind[saved_label] = new_index + self.ind2lab[new_index] = saved_label + + def add_unk(self, unk_label=DEFAULT_UNK): + """Add label for unknown tokens (out-of-vocab). + + When asked to encode unknown labels, they can be mapped to this. + + Arguments + --------- + unk_label : hashable, optional + Most often labels are str, but anything that can act as dict key is + supported. Note that default save/load only supports Python + literals. Default: . This can be None, as well! + + Returns + ------- + int + The index that was used to encode this. + """ + self.unk_label = unk_label + return self.add_label(unk_label) + + def _next_index(self): + """The index to use for the next new label""" + index = self.starting_index + while index in self.ind2lab: + index += 1 + return index + + def is_continuous(self): + """Check that the set of indices doesn't have gaps + + For example: + If starting index = 1 + Continuous: [1,2,3,4] + Continuous: [0,1,2] + Non-continuous: [2,3,4] + Non-continuous: [1,2,4] + + Returns + ------- + bool + True if continuous. + """ + # Because of Python indexing this also handles the special cases + # of 0 or 1 labels. + indices = sorted(self.ind2lab.keys()) + return self.starting_index in indices and all( + j - i == 1 for i, j in zip(indices[:-1], indices[1:]) + ) + + def encode_label(self, label, allow_unk=True): + """Encode label to int + + Arguments + --------- + label : hashable + Label to encode, must exist in the mapping. + allow_unk : bool + If given, that label is not in the label set + AND unk_label has been added with add_unk(), + allows encoding to unk_label's index. + + Returns + ------- + int + Corresponding encoded int value. + """ + self._assert_len() + try: + return self.lab2ind[label] + except KeyError: + if hasattr(self, "unk_label") and allow_unk: + return self.lab2ind[self.unk_label] + elif hasattr(self, "unk_label") and not allow_unk: + raise KeyError( + f"Unknown label {label}, and explicitly " + "disallowed the use of the existing unk-label" + ) + elif not hasattr(self, "unk_label") and allow_unk: + raise KeyError( + f"Cannot encode unknown label {label}. " + "You have not called add_unk() to add a special " + "unk-label for unknown labels." + ) + else: + raise KeyError( + f"Couldn't and wouldn't encode unknown label {label}." + ) + + def encode_label_torch(self, label, allow_unk=True): + """Encode label to torch.LongTensor. + + Arguments + --------- + label : hashable + Label to encode, must exist in the mapping. + allow_unk : bool + If given, that label is not in the label set + AND unk_label has been added with add_unk(), + allows encoding to unk_label's index. + + Returns + ------- + torch.LongTensor + Corresponding encoded int value. + Tensor shape [1]. + """ + return torch.LongTensor([self.encode_label(label, allow_unk)]) + + def encode_sequence(self, sequence, allow_unk=True): + """Encode a sequence of labels to list + + Arguments + --------- + sequence : iterable + Labels to encode, must exist in the mapping. + allow_unk : bool + If given, that label is not in the label set + AND unk_label has been added with add_unk(), + allows encoding to unk_label's index. + + Returns + ------- + list + Corresponding integer labels. + """ + self._assert_len() + return [self.encode_label(label, allow_unk) for label in sequence] + + def encode_sequence_torch(self, sequence, allow_unk=True): + """Encode a sequence of labels to torch.LongTensor + + Arguments + --------- + sequence : iterable + Labels to encode, must exist in the mapping. + allow_unk : bool + If given, that label is not in the label set + AND unk_label has been added with add_unk(), + allows encoding to unk_label's index. + + Returns + ------- + torch.LongTensor + Corresponding integer labels. + Tensor shape [len(sequence)]. + """ + return torch.LongTensor( + [self.encode_label(label, allow_unk) for label in sequence] + ) + + def decode_torch(self, x): + """Decodes an arbitrarily nested torch.Tensor to a list of labels. + + Provided separately because Torch provides clearer introspection, + and so doesn't require try-except. + + Arguments + --------- + x : torch.Tensor + Torch tensor of some integer dtype (Long, int) and any shape to + decode. + + Returns + ------- + list + list of original labels + """ + self._assert_len() + decoded = [] + # Recursively operates on the different dimensions. + if x.ndim == 1: # Last dimension! + for element in x: + decoded.append(self.ind2lab[int(element)]) + else: + for subtensor in x: + decoded.append(self.decode_torch(subtensor)) + return decoded + + def decode_ndim(self, x): + """Decodes an arbitrarily nested iterable to a list of labels. + + This works for essentially any pythonic iterable (including torch), and + also single elements. + + Arguments + --------- + x : Any + Python list or other iterable or torch.Tensor or a single integer element + + Returns + ------- + list, Any + ndim list of original labels, or if input was single element, + output will be, too. + """ + self._assert_len() + # Recursively operates on the different dimensions. + try: + decoded = [] + for subtensor in x: + decoded.append(self.decode_ndim(subtensor)) + return decoded + except TypeError: # Not an iterable, bottom level! + return self.ind2lab[int(x)] + + @mark_as_saver + def save(self, path): + """Save the categorical encoding for later use and recovery + + Saving uses a Python literal format, which supports things like + tuple labels, but is considered safe to load (unlike e.g. pickle). + + Arguments + --------- + path : str, Path + Where to save. Will overwrite. + """ + extras = self._get_extras() + self._save_literal(path, self.lab2ind, extras) + + def load(self, path): + """Loads from the given path. + + CategoricalEncoder uses a Python literal format, which supports things + like tuple labels, but is considered safe to load (unlike e.g. pickle). + + Arguments + --------- + path : str, Path + Where to load from. + """ + if self.lab2ind: + clsname = self.__class__.__name__ + logger.info( + f"Load called, but {clsname} is not empty. " + "Loaded data will overwrite everything. " + "This is normal if there is e.g. an unk label defined at init." + ) + lab2ind, ind2lab, extras = self._load_literal(path) + self.lab2ind = lab2ind + self.ind2lab = ind2lab + self._set_extras(extras) + # If we're here, load was a success! + logger.debug(f"Loaded categorical encoding from {path}") + + @mark_as_loader + def load_if_possible(self, path, end_of_epoch=False): + """Loads if possible, returns a bool indicating if loaded or not. + + Arguments + --------- + path : str, Path + Where to load from. + end_of_epoch : bool + Whether the checkpoint was end-of-epoch or not. + + Returns + ------- + bool : + If load was successful. + + Example + ------- + >>> encoding_file = getfixture("tmpdir") / "encoding.txt" + >>> encoder = CategoricalEncoder() + >>> # The idea is in an experiment script to have something like this: + >>> if not encoder.load_if_possible(encoding_file): + ... encoder.update_from_iterable("abcd") + ... encoder.save(encoding_file) + >>> # So the first time you run the experiment, the encoding is created. + >>> # However, later, the encoding exists: + >>> encoder = CategoricalEncoder() + >>> encoder.expect_len(4) + >>> if not encoder.load_if_possible(encoding_file): + ... assert False # We won't get here! + >>> encoder.decode_ndim(range(4)) + ['a', 'b', 'c', 'd'] + """ + del end_of_epoch # Unused here. + + try: + self.load(path) + except FileNotFoundError: + logger.debug( + f"Would load categorical encoding from {path}, " + "but file doesn't exist yet." + ) + return False + except (ValueError, SyntaxError): + logger.debug( + f"Would load categorical encoding from {path}, " + "and file existed but seems to be corrupted or otherwise couldn't load." + ) + return False + return True # If here, all good + + def expect_len(self, expected_len): + """Specify the expected category count. If the category count observed + during encoding/decoding does NOT match this, an error will be raised. + + This can prove useful to detect bugs in scenarios where the encoder is + dynamically built using a dataset, but downstream code expects a + specific category count (and may silently break otherwise). + + This can be called anytime and the category count check will only be + performed during an actual encoding/decoding task. + + Arguments + --------- + expected_len : int + The expected final category count, i.e. `len(encoder)`. + + Example + ------- + >>> encoder = CategoricalEncoder() + >>> encoder.update_from_iterable("abcd") + >>> encoder.expect_len(3) + >>> encoder.encode_label("a") + Traceback (most recent call last): + ... + RuntimeError: .expect_len(3) was called, but 4 categories found + >>> encoder.expect_len(4) + >>> encoder.encode_label("a") + 0 + """ + self.expected_len = expected_len + + def ignore_len(self): + """Specifies that category count shall be ignored at encoding/decoding + time. + + Effectively inhibits the ".expect_len was never called" warning. + Prefer :py:meth:`~CategoricalEncoder.expect_len` when the category count + is known.""" + self.expected_len = None + + def _assert_len(self): + """If `expect_len` was called, then check if len(self) matches the + expected value. If it does not, raise a RuntimeError. + If neither `expect_len` or `ignore_len` were ever called, warn once.""" + if hasattr(self, "expected_len"): + # skip when ignore_len() was called + if self.expected_len is None: + return + + real_len = len(self) + + if real_len != self.expected_len: + raise RuntimeError( + f".expect_len({self.expected_len}) was called, " + f"but {real_len} categories found" + ) + else: + logger.warning_once( + f"{self.__class__.__name__}.expect_len was never called: " + f"assuming category count of {len(self)} to be correct! " + "Sanity check your encoder using `.expect_len`. " + "Ensure that downstream code also uses the correct size. " + "If you are sure this does not apply to you, use `.ignore_len`." + ) + self.ignore_len() + return + + def _get_extras(self): + """Override this to provide any additional things to save + + Call super()._get_extras() to get the base extras + """ + extras = {"starting_index": self.starting_index} + if hasattr(self, "unk_label"): + extras["unk_label"] = self.unk_label + return extras + + def _set_extras(self, extras): + """Override this to e.g. load any extras needed + + Call super()._set_extras(extras) to set the base extras + """ + if "unk_label" in extras: + self.unk_label = extras["unk_label"] + self.starting_index = extras["starting_index"] + + @staticmethod + def _save_literal(path, lab2ind, extras): + """Save which is compatible with _load_literal""" + with open(path, "w", encoding="utf-8") as f: + for label, ind in lab2ind.items(): + f.write( + repr(label) + + CategoricalEncoder.VALUE_SEPARATOR + + str(ind) + + "\n" + ) + f.write(CategoricalEncoder.EXTRAS_SEPARATOR) + for key, value in extras.items(): + f.write( + repr(key) + + CategoricalEncoder.VALUE_SEPARATOR + + repr(value) + + "\n" + ) + f.flush() + + @staticmethod + def _load_literal(path): + """Load which supports Python literals as keys. + + This is considered safe for user input, as well (unlike e.g. pickle). + """ + lab2ind = {} + ind2lab = {} + extras = {} + with open(path, encoding="utf-8") as f: + # Load the label to index mapping (until EXTRAS_SEPARATOR) + for line in f: + if line == CategoricalEncoder.EXTRAS_SEPARATOR: + break + literal, ind = line.strip().split( + CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1 + ) + ind = int(ind) + label = ast.literal_eval(literal) + lab2ind[label] = ind + ind2lab[ind] = label + # Load the extras: + for line in f: + literal_key, literal_value = line.strip().split( + CategoricalEncoder.VALUE_SEPARATOR, maxsplit=1 + ) + key = ast.literal_eval(literal_key) + value = ast.literal_eval(literal_value) + extras[key] = value + return lab2ind, ind2lab, extras + + +class TextEncoder(CategoricalEncoder): + """CategoricalEncoder subclass which offers specific methods for encoding text and handle + special tokens for training of sequence to sequence models. + In detail, aside special token already present in CategoricalEncoder + for handling out-of-vocab tokens here special methods to handle + beginning of sequence and tokens are defined. + + Note: update_from_iterable and update_from_didataset here have as default + sequence_input=True because it is assumed that this encoder is used on + iterables of strings: e.g. + + >>> from speechbrain.dataio.encoder import TextEncoder + >>> dataset = [["encode", "this", "textencoder"], ["foo", "bar"]] + >>> encoder = TextEncoder() + >>> encoder.update_from_iterable(dataset) + >>> encoder.expect_len(5) + >>> encoder.encode_label("this") + 1 + >>> encoder.add_unk() + 5 + >>> encoder.expect_len(6) + >>> encoder.encode_sequence(["this", "out-of-vocab"]) + [1, 5] + >>> + + Two methods can be used to add and to the internal dicts: + insert_bos_eos, add_bos_eos. + + >>> encoder.add_bos_eos() + >>> encoder.expect_len(8) + >>> encoder.lab2ind[encoder.eos_label] + 7 + >>> + add_bos_eos adds the special tokens at the end of the dict indexes + >>> encoder = TextEncoder() + >>> encoder.update_from_iterable(dataset) + >>> encoder.insert_bos_eos(bos_index=0, eos_index=1) + >>> encoder.expect_len(7) + >>> encoder.lab2ind[encoder.eos_label] + 1 + >>> + insert_bos_eos allows to specify whose index will correspond to each of them. + Note that you can also specify the same integer encoding for both. + + Four methods can be used to prepend and append . + prepend_bos_label and append_eos_label add respectively the and + string tokens to the input sequence + + >>> words = ["foo", "bar"] + >>> encoder.prepend_bos_label(words) + ['', 'foo', 'bar'] + >>> encoder.append_eos_label(words) + ['foo', 'bar', ''] + + prepend_bos_index and append_eos_index add respectively the and + indexes to the input encoded sequence. + + >>> words = ["foo", "bar"] + >>> encoded = encoder.encode_sequence(words) + >>> encoder.prepend_bos_index(encoded) + [0, 3, 4] + >>> encoder.append_eos_index(encoded) + [3, 4, 1] + + """ + + def handle_special_labels(self, special_labels): + """Handles special labels such as bos and eos.""" + super().handle_special_labels(special_labels) + # NOTE: bos_label and eos_label are not necessarily set at all! + # This is because None is a suitable value. + # So the test is: hasattr(self, "bos_label") + # rather than self.bos_label is not None + # Same thing with unk, see base class. + if "bos_label" in special_labels and "eos_label" in special_labels: + self.insert_bos_eos( + bos_label="", + eos_label="", + bos_index=special_labels["bos_label"], + eos_index=special_labels["eos_label"], + ) + elif "bos_label" in special_labels or "eos_label" in special_labels: + raise TypeError("Only BOS or EOS specified. Need both for init.") + + def update_from_iterable(self, iterable, sequence_input=True): + """Change default for sequence_input to True.""" + return super().update_from_iterable(iterable, sequence_input) + + def update_from_didataset(self, didataset, output_key, sequence_input=True): + """Change default for sequence_input to True.""" + return super().update_from_didataset( + didataset, output_key, sequence_input + ) + + def limited_labelset_from_iterable( + self, iterable, sequence_input=True, n_most_common=None, min_count=1 + ): + """Change default for sequence_input to True.""" + return super().limited_labelset_from_iterable( + iterable, sequence_input=True, n_most_common=None, min_count=1 + ) + + def add_bos_eos( + self, + bos_label=DEFAULT_BOS, + eos_label=DEFAULT_EOS, + ): + """Add sentence boundary markers in the label set. + + If the beginning-of-sentence and end-of-sentence markers + are the same, will just use one sentence-boundary label. + + This method adds to the end of the index, rather than at the beginning, + like insert_bos_eos. + + Arguments + --------- + bos_label : hashable + Beginning-of-sentence label, any label. + eos_label : hashable + End-of-sentence label, any label. If set to the same label as + bos_label, will just use one sentence-boundary label. + """ + if bos_label == eos_label: + logger.debug( + "BOS and EOS labels are the same so using just one sentence " + "boundary label" + ) + self.add_label(bos_label) + else: + self.add_label(bos_label) + self.add_label(eos_label) + self.bos_label = bos_label + self.eos_label = eos_label + + def insert_bos_eos( + self, + bos_label=DEFAULT_BOS, + eos_label=DEFAULT_EOS, + bos_index=0, + eos_index=None, + ): + """Insert sentence boundary markers in the label set. + + If the beginning-of-sentence and end-of-sentence markers + are the same, will just use one sentence-boundary label. + + Arguments + --------- + bos_label : hashable + Beginning-of-sentence label, any label + eos_label : hashable + End-of-sentence label, any label. If set to the same label as + bos_label, will just use one sentence-boundary label. + bos_index : int + Where to insert bos_label. eos_index = bos_index + 1 + eos_index : optional, int + Where to insert eos_label. Default: eos_index = bos_index + 1 + """ + if bos_label == eos_label: + logger.debug( + "BOS and EOS labels are the same so using just one sentence " + "boundary label" + ) + self.insert_label(bos_label, bos_index) + else: + self.insert_label(bos_label, bos_index) + if eos_index is None: + logger.debug("EOS label not specified, using BOS label + 1") + self.insert_label(eos_label, bos_index + 1) + else: + self.insert_label(eos_label, eos_index) + self.bos_label = bos_label + self.eos_label = eos_label + + def get_bos_index(self): + """Returns the index to which blank encodes""" + if not hasattr(self, "bos_label"): + raise RuntimeError("BOS label is not set!") + return self.encode_label(self.bos_label) + + def get_eos_index(self): + """Returns the index to which blank encodes""" + if not hasattr(self, "eos_label"): + raise RuntimeError("EOS label is not set!") + return self.encode_label(self.eos_label) + + def prepend_bos_label(self, x): + """Returns a list version of x, with BOS prepended""" + if not hasattr(self, "bos_label"): + raise KeyError("BOS label has not been added to label set!") + return [self.bos_label] + list(x) + + def prepend_bos_index(self, x): + """Returns a list version of x, with BOS index prepended. + If the input is a tensor, a tensor is returned.""" + if not hasattr(self, "bos_label"): + raise KeyError("BOS label has not been added to label set!") + if torch.is_tensor(x): + bos_ind = torch.Tensor([self.lab2ind[self.bos_label]]) + return torch.cat([bos_ind, x]) + return [self.lab2ind[self.bos_label]] + list(x) + + def append_eos_label(self, x): + """Returns a list version of x, with EOS appended.""" + if not hasattr(self, "eos_label"): + raise KeyError("EOS label has not been added to label set!") + return list(x) + [self.eos_label] + + def append_eos_index(self, x): + """Returns a list version of x, with EOS index appended. + If the input is a tensor, a tensor is returned.""" + if not hasattr(self, "eos_label"): + raise KeyError("EOS label has not been added to label set!") + if torch.is_tensor(x): + eos_ind = torch.Tensor([self.lab2ind[self.eos_label]]) + return torch.cat([x, eos_ind]) + return list(x) + [self.lab2ind[self.eos_label]] + + def _get_extras(self): + extras = super()._get_extras() + if hasattr(self, "bos_label"): + extras["bos_label"] = self.bos_label + if hasattr(self, "eos_label"): + extras["eos_label"] = self.eos_label + return extras + + def _set_extras(self, extras): + super()._set_extras(extras) + if "bos_label" in extras: + self.bos_label = extras["bos_label"] + if "eos_label" in extras: + self.eos_label = extras["eos_label"] + + +class CTCTextEncoder(TextEncoder): + """Subclass of TextEncoder which also provides methods to handle CTC blank token. + + add_blank and insert_blank can be used to add special token to the encoder + state. + + >>> from speechbrain.dataio.encoder import CTCTextEncoder + >>> chars = ["a", "b", "c", "d"] + >>> encoder = CTCTextEncoder() + >>> encoder.update_from_iterable(chars) + >>> encoder.add_blank() + >>> encoder.expect_len(5) + >>> encoder.encode_sequence(chars) + [0, 1, 2, 3] + >>> encoder.get_blank_index() + 4 + >>> encoder.decode_ndim([0, 1, 2, 3, 4]) + ['a', 'b', 'c', 'd', ''] + + collapse_labels and collapse_indices_ndim can be used to apply CTC collapsing + rules: + >>> encoder.collapse_labels(["a", "a", "b", "c", "d"]) + ['a', 'b', 'c', 'd'] + >>> encoder.collapse_indices_ndim([4, 4, 0, 1, 2, 3, 4, 4]) # 4 is + [0, 1, 2, 3] + """ + + def handle_special_labels(self, special_labels): + """Handles special labels such as blanks.""" + # super().handle_special_labels(special_labels) + # NOTE: blank_label is not necessarily set at all! + # This is because None is a suitable value. + # So the test is: hasattr(self, "blank_label") + # rather than self.blank_label is not None + # Same thing with unk, see base class. + if "blank_label" in special_labels: + self.insert_blank(index=special_labels["blank_label"]) + + super().handle_special_labels(special_labels) + + def add_blank(self, blank_label=DEFAULT_BLANK): + """Add blank symbol to labelset.""" + self.add_label(blank_label) + self.blank_label = blank_label + + def insert_blank(self, blank_label=DEFAULT_BLANK, index=0): + """Insert blank symbol at a given labelset.""" + self.insert_label(blank_label, index) + self.blank_label = blank_label + + def get_blank_index(self): + """Returns the index to which blank encodes.""" + if not hasattr(self, "blank_label"): + raise RuntimeError("Blank label is not set!") + return self.encode_label(self.blank_label) + + def collapse_labels(self, x, merge_repeats=True): + """Applies the CTC collapsing rules on one label sequence. + + Arguments + --------- + x : iterable + Label sequence on which to operate. + merge_repeats : bool + Whether to merge repeated labels before removing blanks. + In the basic CTC label topology, repeated labels are merged. + However, in RNN-T, they are not. + + Returns + ------- + list + List of labels with collapsing rules applied. + """ + # This cannot work on arbitrary "ndim", because strings can be + # infinitely iterated. Iterating "a" produces "a" over and over again. + if not hasattr(self, "blank_label"): + raise KeyError("Blank label has not been added") + if merge_repeats: + return [ + label + for i, label in enumerate(x) + if (i == 0 or label != x[i - 1]) and label != self.blank_label + ] + else: + return [label for label in x if label != self.blank_label] + + def collapse_indices_ndim(self, x, merge_repeats=True): + """Applies the CTC collapsing rules on arbitrarily label sequence. + + Arguments + --------- + x : iterable + Label sequence on which to operate. + merge_repeats : bool + Whether to merge repeated labels before removing blanks. + In the basic CTC label topology, repeated labels are merged. + However, in RNN-T, they are not. + + Returns + ------- + list + List of labels with collapsing rules applied. + """ + if not hasattr(self, "blank_label"): + raise KeyError("Blank label has not been added") + # Recursively operates on the different dimensions. + collapsed = [] + for subtensor in x: + try: + collapsed.append( + self.collapse_indices_ndim(subtensor, merge_repeats) + ) + except TypeError: # Not an iterable at next level! + # So we should rather operate on this dimension. + break + else: # For-else: only enter else if NO break. + return collapsed + # We get here if we DID break: + blank_index = self.lab2ind[self.blank_label] + if merge_repeats: + return [ + index + for i, index in enumerate(x) + if (i == 0 or index != x[i - 1]) and index != blank_index + ] + else: + return [index for index in x if index != blank_index] + + def _get_extras(self): + extras = super()._get_extras() + if hasattr(self, "blank_label"): + extras["blank_label"] = self.blank_label + return extras + + def _set_extras(self, extras): + super()._set_extras(extras) + if "blank_label" in extras: + self.blank_label = extras["blank_label"] + + +def load_text_encoder_tokens(model_path): + """Loads the encoder tokens from a pretrained model. + + This method is useful when you used with a pretrained HF model. + It will load the tokens in the yaml and then you will be able + to instantiate any CTCBaseSearcher directly in the YAML file. + + Arguments + --------- + model_path : str, Path + Path to the pretrained model. + + Returns + ------- + list + List of tokens. + """ + label_encoder = TextEncoder() + label_encoder.load(model_path) + return list(label_encoder.lab2ind.keys()) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/iterators.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/iterators.py new file mode 100644 index 000000000..19515329f --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/iterators.py @@ -0,0 +1,235 @@ +"""Webdataset compatible iterators + +Authors: + * Aku Rouhe 2021 +""" + +import bisect +import random +from dataclasses import dataclass, field +from functools import partial +from typing import Any + +from speechbrain.dataio.batch import PaddedBatch + + +@dataclass(order=True) +class LengthItem: + """Data class for lengths""" + + length: int + data: Any = field(compare=False) + + +def total_length_with_padding(lengths): + """Determines how long would batch be (with padding)""" + return len(lengths) * max(lengths) + + +def padding_ratio(lengths): + """Determines how much of batch is padding.""" + return 1.0 - sum(lengths) / total_length_with_padding(lengths) + + +@dataclass(order=True) +class RatioIndex: + "Data class for Ratio." + + ratio: float + index: int + + +def indices_around_random_pivot( + databuffer, + target_batch_numel, + max_batch_size=None, + max_batch_numel=None, + max_padding_ratio=0.2, + randint_generator=random.randint, +): + """Random pivot sampler_fn for dynamic_bucketed_batch + + Create a batch around a random pivot index in the sorted buffer + + This works on the databuffer which is assumed to be in sorted order. An + index is chosen at random. This starts the window of indices: at first, + only the randomly chosen pivot index is included. The window of indices is + grown one-index-at-a-time, picking either the index to the right of the + window, or the index to the left, picking the index that would increase the + padding ratio the least, and making sure the batch wouldn't exceed the + maximum batch length nor the maximum padding ratio. + + Arguments + --------- + databuffer : list + Sorted list of LengthItems + target_batch_numel : int + Target of total batch length including padding, which is simply computed + as batch size * length of longest example. This function aims to return + the batch as soon as the gathered length exceeds this. If some limits + are encountered first, this may not be satisfied. + max_batch_size : None, int + Maximum number of examples to include in the batch, or None to not limit + by number of examples. + max_batch_numel : None, int + Maximum of total batch length including padding, which is simply computed + as batch size * length of longest example. + max_padding_ratio : float + Each batch can have at most this much devoted to padding. + randint_generator : generator + Provide a generator to get reproducible results. + + Returns + ------- + indices : list + A list of consecutive indices. + """ + bufferlen = len(databuffer) + if max_batch_size is None: + max_batch_size = bufferlen + # Choose pivot: + min_index = max_index = randint_generator(0, bufferlen - 1) + lengths = [databuffer[min_index].length] + + # Define index filtering function: + def possibly_consider(index, to_consider): + """Adds an index to the to_consider list, if the index passes all + requirements.""" + if index < 0 or index >= len(databuffer): + return + consideree = databuffer[index] + updated_lengths = [consideree.length] + lengths + if max_batch_numel is not None: + updated_total = total_length_with_padding(updated_lengths) + if updated_total > max_batch_numel: + return + updated_ratio = padding_ratio(updated_lengths) + if max_padding_ratio is not None and updated_ratio > max_padding_ratio: + return + to_consider.append(RatioIndex(updated_ratio, index)) + + # Loop till the target length is exceeded or max batch size is hit: + while ( + max_index + 1 - min_index < max_batch_size + and total_length_with_padding(lengths) < target_batch_numel + ): + # Consider indices to the left and to the right, if they + # pass the requirements: + to_consider = [] + possibly_consider(min_index - 1, to_consider) + possibly_consider(max_index + 1, to_consider) + # If neither pass the requirements, then we must return the batch + # as it is now (there can be no better addition): + if not to_consider: + break + # Pick the index that minimizes the padding ratio increase: + to_add = min(to_consider) + min_index = min(min_index, to_add.index) + max_index = max(max_index, to_add.index) + lengths.append(databuffer[to_add.index].length) + return list(range(min_index, max_index + 1)) + + +def dynamic_bucketed_batch( + data, + len_key=None, + len_fn=len, + min_sample_len=None, + max_sample_len=None, + buffersize=1024, + collate_fn=PaddedBatch, + sampler_fn=indices_around_random_pivot, + sampler_kwargs={}, + drop_end=False, +): + """Produce batches from a sorted buffer + + This function keeps a sorted buffer of the incoming samples. + The samples can be filtered for min/max length. + An external sampler is used to choose samples for each batch, + which allows different dynamic batching algorithms to be used. + + Arguments + --------- + data : iterable + An iterable source of samples, such as an IterableDataset. + len_key : str, None + The key in the sample dict to use to fetch the length of the sample, or + None if no key should be used. + len_fn : callable + Called with sample[len_key] if len_key is not None, else sample. Needs + to return the sample length as an integer. + min_sample_len : int, None + Discard samples with length lower than this. If None, no minimum is + applied. + max_sample_len : int, None + Discard samples with length larger than this. If None, no maximum is + applied. + buffersize : int + The size of the internal sorted buffer. The buffer is always filled up + before yielding a batch of samples. + collate_fn : callable + Called with a list of samples. This should return a batch. By default, using + the SpeechBrain PaddedBatch class, which works for dict-like samples, and + pads any tensors. + sampler_fn : callable + Called with the sorted data buffer. Needs to return a list of indices, which + make up the next batch. By default using ``indices_around_random_pivot`` + sampler_kwargs : dict + Keyword arguments, passed to sampler_fn. + drop_end : bool + After the data stream is exhausted, should batches be made until the data + buffer is exhausted, or should the rest of the buffer be discarded. Without + new samples, the last batches might not be efficient to process. + Note: you can use ``.repeat`` on `webdataset` IterableDatasets to never + run out of new samples, and then use + `speechbrain.dataio.dataloader.LoopedLoader` to set a nominal epoch length. + + Yields + ------ + Batches + """ + databuffer = [] + if sampler_kwargs: + sampler_fn = partial(sampler_fn, **sampler_kwargs) + for sample in data: + # Length fetching interface has multiple valid call signatures: + if len_key is not None and len_fn is not None: + length = len_fn(sample[len_key]) + elif len_key is not None: + length = sample[len_key] + elif len_fn is not None: + length = len_fn(sample) + else: + raise ValueError("Must specify at least one of len_key or len_fn") + # Possibly filter by length: + if (min_sample_len is not None and length < min_sample_len) or ( + max_sample_len is not None and length > max_sample_len + ): + # Drop sample + continue + item = LengthItem(length, sample) + # bisect.insort inserts in sorted order. + # This should be a good way to maintain a sorted list, + # but perhaps simply filling up the buffer and calling .sort() + # could be good as well (Python's sort leverages already sorted segments) + bisect.insort(databuffer, item) + if len(databuffer) == buffersize: + indices = sampler_fn(databuffer) + batch_list = [] + # popping from highest to lowest is safe + for i in sorted(indices, reverse=True): + item = databuffer.pop(i) + batch_list.append(item.data) + yield collate_fn(batch_list) + # Data stream was exhausted. Data buffer is relatively full at first, + # but cannot be replenished, so batches might not be efficiently produced. + # Either stop, or exhaust buffer. + if not drop_end: + while databuffer: + indices = sampler_fn(databuffer) + batch_list = [] + for i in sorted(indices, reverse=True): + item = databuffer.pop(i) + batch_list.append(item.data) + yield collate_fn(batch_list) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/legacy.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/legacy.py new file mode 100644 index 000000000..ffebb9888 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/legacy.py @@ -0,0 +1,321 @@ +"""SpeechBrain Extended CSV Compatibility.""" + +import collections +import csv +import pickle +import re + +import torch + +from speechbrain.dataio import audio_io +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +TORCHAUDIO_FORMATS = ["wav", "flac", "aac", "ogg", "flac", "mp3"] +ITEM_POSTFIX = "_data" + +CSVItem = collections.namedtuple("CSVItem", ["data", "format", "opts"]) +CSVItem.__doc__ = """The Legacy Extended CSV Data item triplet""" + + +class ExtendedCSVDataset(DynamicItemDataset): + """Extended CSV compatibility for DynamicItemDataset. + + Uses the SpeechBrain Extended CSV data format, where the CSV must have an + 'ID' and 'duration' fields. + + The rest of the fields come in triplets: + ``, _format, _opts`` + + These add a _sb_data item in the dict. Additionally, a basic + DynamicItem (see DynamicItemDataset) is created, which loads the _sb_data + item. + + Bash-like string replacements with $to_replace are supported. + + NOTE + ---- + Mapping from legacy interface: + + - csv_file -> csvpath + - sentence_sorting -> sorting, and "random" is not supported, use e.g. + ``make_dataloader(..., shuffle = (sorting=="random"))`` + - avoid_if_shorter_than -> min_duration + - avoid_if_longer_than -> max_duration + - csv_read -> output_keys, and if you want IDs add "id" as key + + Arguments + --------- + csvpath : str, path + Path to extended CSV. + replacements : dict + Used for Bash-like $-prefixed substitution, + e.g. ``{"data_folder": "/home/speechbrain/data"}``, which would + transform `$data_folder/utt1.wav` into `/home/speechbrain/data/utt1.wav` + sorting : {"original", "ascending", "descending"} + Keep CSV order, or sort ascending or descending by duration. + min_duration : float, int + Minimum duration in seconds. Discards other entries. + max_duration : float, int + Maximum duration in seconds. Discards other entries. + dynamic_items : list + Configuration for extra dynamic items produced when fetching an + example. List of DynamicItems or dicts with keys:: + func: # To be called + takes: # key or list of keys of args this takes + provides: key # key or list of keys that this provides + NOTE: A dynamic item is automatically added for each CSV data-triplet + output_keys : list, None + The list of output keys to produce. You can refer to the names of the + CSV data-triplets. E.G. if the CSV has: wav,wav_format,wav_opts, + then the Dataset has a dynamic item output available with key ``"wav"`` + NOTE: If None, read all existing. + """ + + def __init__( + self, + csvpath, + replacements={}, + sorting="original", + min_duration=0, + max_duration=36000, + dynamic_items=[], + output_keys=[], + ): + if sorting not in ["original", "ascending", "descending"]: + clsname = self.__class__.__name__ + raise ValueError(f"{clsname} doesn't support {sorting} sorting") + # Load the CSV, init class + data, di_to_add, data_names = load_sb_extended_csv( + csvpath, replacements + ) + super().__init__(data, dynamic_items, output_keys) + self.pipeline.add_dynamic_items(di_to_add) + # Handle filtering, sorting: + reverse = False + sort_key = None + if sorting == "ascending" or "descending": + sort_key = "duration" + if sorting == "descending": + reverse = True + filtered_sorted_ids = self._filtered_sorted_ids( + key_min_value={"duration": min_duration}, + key_max_value={"duration": max_duration}, + sort_key=sort_key, + reverse=reverse, + ) + self.data_ids = filtered_sorted_ids + # Handle None output_keys (differently than Base) + if not output_keys: + self.set_output_keys(data_names) + + +def load_sb_extended_csv(csv_path, replacements=None): + """Loads SB Extended CSV and formats string values. + + Uses the SpeechBrain Extended CSV data format, where the + CSV must have an 'ID' and 'duration' fields. + + The rest of the fields come in triplets: + ``, _format, _opts``. + + These add a _sb_data item in the dict. Additionally, a + basic DynamicItem (see DynamicItemDataset) is created, which + loads the _sb_data item. + + Bash-like string replacements with $to_replace are supported. + + This format has its restriction, but they allow some tasks to + have loading specified by the CSV. + + Arguments + --------- + csv_path : str + Path to the CSV file. + replacements : dict + Optional dict: + e.g. ``{"data_folder": "/home/speechbrain/data"}`` + This is used to recursively format all string values in the data. + + Returns + ------- + dict + CSV data with replacements applied. + list + List of DynamicItems to add in DynamicItemDataset. + + """ + if replacements is None: + replacements = {} + with open(csv_path, newline="", encoding="utf-8") as csvfile: + result = {} + reader = csv.DictReader(csvfile, skipinitialspace=True) + variable_finder = re.compile(r"\$([\w.]+)") + if not reader.fieldnames[0] == "ID": + raise KeyError( + "CSV has to have an 'ID' field, with unique ids" + " for all data points" + ) + if not reader.fieldnames[1] == "duration": + raise KeyError( + "CSV has to have an 'duration' field, " + "with the length of the data point in seconds." + ) + if not len(reader.fieldnames[2:]) % 3 == 0: + raise ValueError( + "All named fields must have 3 entries: " + ", _format, _opts" + ) + names = reader.fieldnames[2::3] + for row in reader: + # Make a triplet for each name + data_point = {} + # ID: + data_id = row["ID"] + del row["ID"] # This is used as a key in result, instead. + # Duration: + data_point["duration"] = float(row["duration"]) + del row["duration"] # This is handled specially. + if data_id in result: + raise ValueError(f"Duplicate id: {data_id}") + # Replacements: + # Only need to run these in the actual data, + # not in _opts, _format + for key, value in list(row.items())[::3]: + try: + row[key] = variable_finder.sub( + lambda match: replacements[match[1]], value + ) + except KeyError: + raise KeyError( + f"The item {value} requires replacements " + "which were not supplied." + ) + for i, name in enumerate(names): + triplet = CSVItem(*list(row.values())[i * 3 : i * 3 + 3]) + data_point[name + ITEM_POSTFIX] = triplet + result[data_id] = data_point + # Make a DynamicItem for each CSV entry + # _read_csv_item delegates reading to further + dynamic_items_to_add = [] + for name in names: + di = { + "func": _read_csv_item, + "takes": name + ITEM_POSTFIX, + "provides": name, + } + dynamic_items_to_add.append(di) + return result, dynamic_items_to_add, names + + +def _read_csv_item(item): + """Reads the different formats supported in SB Extended CSV. + + Delegates to the relevant functions. + """ + opts = _parse_csv_item_opts(item.opts) + if item.format in TORCHAUDIO_FORMATS: + audio, _ = audio_io.load(item.data) + return audio.squeeze(0) + elif item.format == "pkl": + return read_pkl(item.data, opts) + elif item.format == "string": + # Just implement string reading here. + # NOTE: No longer supporting + # lab2ind mapping like before. + # Try decoding string + string = item.data + try: + string = string.decode("utf-8") + except AttributeError: + pass + # Splitting elements with ' ' + string = string.split(" ") + return string + else: + raise TypeError(f"Don't know how to read {item.format}") + + +def _parse_csv_item_opts(entry): + """Parse the _opts field in a SB Extended CSV item.""" + # Accepting even slightly weirdly formatted entries: + entry = entry.strip() + if len(entry) == 0: + return {} + opts = {} + for opt in entry.split(" "): + opt_name, opt_val = opt.split(":") + opts[opt_name] = opt_val + return opts + + +def read_pkl(file, data_options=None, lab2ind=None): + """This function reads tensors store in pkl format. + + Arguments + --------- + file : str + The path to file to read. + data_options : dict, optional + A dictionary containing options for the reader. + lab2ind : dict, optional + Mapping from label to integer indices. + + Returns + ------- + numpy.array + The array containing the read signal. + """ + + if data_options is None: + data_options = {} + # Trying to read data + try: + with open(file, "rb") as f: + pkl_element = pickle.load(f) + except pickle.UnpicklingError: + err_msg = "cannot read the pkl file %s" % (file) + raise ValueError(err_msg) + + type_ok = False + + if isinstance(pkl_element, list): + if isinstance(pkl_element[0], float): + tensor = torch.FloatTensor(pkl_element) + type_ok = True + + if isinstance(pkl_element[0], int): + tensor = torch.LongTensor(pkl_element) + type_ok = True + + if isinstance(pkl_element[0], str): + # convert string to integer as specified in self.label_dict + if lab2ind is not None: + for index, val in enumerate(pkl_element): + pkl_element[index] = lab2ind[val] + + tensor = torch.LongTensor(pkl_element) + type_ok = True + + if not type_ok: + err_msg = ( + "The pkl file %s can only contain list of integers, " + "floats, or strings. Got %s" + ) % (file, type(pkl_element[0])) + raise ValueError(err_msg) + else: + tensor = pkl_element + + tensor_type = tensor.dtype + + # Conversion to 32 bit (if needed) + if tensor_type == torch.float64: + tensor = tensor.to(torch.float32) + + if tensor_type == torch.int64: + tensor = tensor.to(torch.int32) + + return tensor diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/preprocess.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/preprocess.py new file mode 100644 index 000000000..85e8d45ba --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/preprocess.py @@ -0,0 +1,82 @@ +"""Preprocessors for audio""" + +import torch + +from speechbrain.augment.time_domain import Resample + + +class AudioNormalizer: + """Normalizes audio into a standard format + + Arguments + --------- + sample_rate : int + The sampling rate to which the incoming signals should be converted. + mix : {"avg-to-mono", "keep"} + "avg-to-mono" - add all channels together and normalize by number of + channels. This also removes the channel dimension, resulting in [time] + format tensor. + "keep" - don't normalize channel information + + Example + ------- + >>> from speechbrain.dataio import audio_io + >>> example_file = ( + ... "tests/samples/multi-mic/speech_-0.82918_0.55279_-0.082918.flac" + ... ) + >>> signal, sr = audio_io.load(example_file, channels_first=False) + >>> normalizer = AudioNormalizer(sample_rate=8000) + >>> normalized = normalizer(signal, sr) + >>> signal.shape + torch.Size([160000, 4]) + >>> normalized.shape + torch.Size([80000]) + + NOTE + ---- + This will also upsample audio. However, upsampling cannot produce meaningful + information in the bandwidth which it adds. Generally models will not work + well for upsampled data if they have not specifically been trained to do so. + """ + + def __init__(self, sample_rate=16000, mix="avg-to-mono"): + self.sample_rate = sample_rate + if mix not in ["avg-to-mono", "keep"]: + raise ValueError(f"Unexpected mixing configuration {mix}") + self.mix = mix + self._cached_resamplers = {} + + def __call__(self, audio, sample_rate): + """Perform normalization + + Arguments + --------- + audio : torch.Tensor + The input waveform torch tensor. Assuming [time, channels], + or [time]. + sample_rate : int + Rate the audio was sampled at. + + Returns + ------- + audio : torch.Tensor + Channel- and sample-rate-normalized audio. + """ + if sample_rate not in self._cached_resamplers: + # Create a Resample instance from this newly seen SR to internal SR + self._cached_resamplers[sample_rate] = Resample( + sample_rate, self.sample_rate + ) + resampler = self._cached_resamplers[sample_rate] + resampled = resampler(audio.unsqueeze(0)).squeeze(0) + return self._mix(resampled) + + def _mix(self, audio): + """Handle channel mixing""" + flat_input = audio.dim() == 1 + if self.mix == "avg-to-mono": + if flat_input: + return audio + return torch.mean(audio, 1) + if self.mix == "keep": + return audio diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/sampler.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/sampler.py new file mode 100644 index 000000000..8fa862b2d --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/sampler.py @@ -0,0 +1,845 @@ +"""PyTorch compatible samplers. + +These determine the order of iteration through a dataset. + +Authors: + * Aku Rouhe 2020 + * Samuele Cornell 2020 + * Ralf Leibold 2020 + * Artem Ploujnikov 2021 + * Andreas Nautsch 2021, 2023 + * Adel Moumen 2023 +""" + +from collections import Counter +from operator import itemgetter +from typing import List, Optional, Union + +import numpy as np +import torch +from scipy.stats import lognorm +from torch.utils.data import ( + DistributedSampler, + RandomSampler, + Sampler, + WeightedRandomSampler, +) + +from speechbrain.dataio.dataset import DynamicItemDataset +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +class ReproducibleRandomSampler(RandomSampler): + """A modification of RandomSampler which always returns the same values. + + Also look at `torch.utils.data.RandomSampler`. This has mostly + the same behaviour and arguments, except for adding 'seed' and 'epoch' and + not supporting 'generator'. + + Note + ---- + Call `set_epoch` before every epoch. Otherwise, the sampler will produce the + same sequence of indices every epoch. + + Arguments + --------- + data_source : Dataset + The data source to sample indices for. + seed : int + The base seed to use for the random number generator. It is recommended + to use a value which has a good mix of 0 and 1 bits. + epoch : int + The epoch to start at. + **kwargs : dict + Arguments to pass to parent class. + + Example + ------- + >>> import torch + >>> from speechbrain.utils.checkpoints import Checkpointer + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> # An example "dataset" + >>> dataset = torch.arange(10).unsqueeze(1) + >>> # Create the random sampler: + >>> sampler = ReproducibleRandomSampler(dataset) + >>> dataloader = SaveableDataLoader(dataset, sampler=sampler, num_workers=3) + >>> # Setup the checkpointer. + >>> # Note that the sampler doesn't need to be saved itself. + >>> tmpdir = getfixture("tmpdir") + >>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader}) + >>> # Iterate: + >>> subset = [] + >>> for i, data_point in enumerate(dataloader): + ... # Say you save a checkpoint on the fourth batch: + ... if i == 3: + ... _ = checkpointer.save_checkpoint(end_of_epoch=False) + ... # So let's save the numbers you would get if you continue + ... if i >= 4: + ... subset.append(data_point.item()) + >>> # What if instead you had to restart the experiment? + >>> new_sampler = ReproducibleRandomSampler(dataset) + >>> new_dataloader = SaveableDataLoader( + ... dataset, sampler=new_sampler, num_workers=3 + ... ) + >>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader}) + >>> _ = new_checkpointer.recover_if_possible() + >>> # You'll get the same random order again: + >>> new_subset = [data_point.item() for data_point in new_dataloader] + >>> assert subset == new_subset + + """ + + def __init__(self, data_source, seed=563375142, epoch=0, **kwargs): + if "generator" in kwargs: + MSG = ( + "Cannot give a separate generator when using " + + "ReproducibleRandomSampler" + ) + raise ValueError(MSG) + super().__init__(data_source, **kwargs) + self.seed = int(seed) + self.epoch = epoch + self.generator = torch.Generator() + + def set_epoch(self, epoch): + """ + You can also just access self.epoch, but we maintain this interface + to mirror torch.utils.data.distributed.DistributedSampler + """ + self.epoch = epoch + + def __iter__(self): + self.generator.manual_seed(self.seed + self.epoch) + return super().__iter__() + + +class ReproducibleWeightedRandomSampler(WeightedRandomSampler): + """A reproducible modification of WeightedRandomSampler. + + Also look at `torch.utils.data.WeightedRandomSampler`. This has the + the same behaviour and arguments, except for adding 'seed' and 'epoch' and + not supporting 'generator'. + + Note + ---- + Call `set_epoch` before every epoch. Otherwise, the sampler will produce the + same sequence of indices every epoch. + + Arguments + --------- + weights : sequence of float + Weights for each index. Doesn't need to sum to one. + num_samples : int + Number of samples to draw + replacement : bool + To draw with replacement or not (within an epoch of num_samples). + seed : int + The base seed to use for the random number generator. It is recommended + to use a value which has a good mix of 0 and 1 bits. + epoch : int + The epoch to start at. + **kwargs : dict + Arguments to pass to parent class. + + Example + ------- + >>> a = ReproducibleWeightedRandomSampler( + ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True + ... ) + >>> b = ReproducibleWeightedRandomSampler( + ... [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True + ... ) + >>> list(a) + [3, 1, 4, 4, 4] + >>> list(b) + [3, 1, 4, 4, 4] + >>> a.set_epoch(1) + >>> list(a) + [4, 5, 4, 4, 3] + >>> b.set_epoch(1) + >>> list(b) + [4, 5, 4, 4, 3] + + + """ + + def __init__( + self, + weights, + num_samples, + replacement, + seed=129491412, + epoch=0, + **kwargs, + ): + if "generator" in kwargs: + MSG = ( + "Cannot give a separate generator when using " + + "ReproducibleRandomSampler" + ) + raise ValueError(MSG) + super().__init__(weights, num_samples, replacement, **kwargs) + self.seed = int(seed) + self.epoch = epoch + self.generator = torch.Generator() + + def set_epoch(self, epoch): + """ + You can also just access self.epoch, but we maintain this interface + to mirror torch.utils.data.distributed.DistributedSampler + """ + self.epoch = epoch + + def __iter__(self): + self.generator.manual_seed(self.seed + self.epoch) + return super().__iter__() + + +class ConcatDatasetBatchSampler(Sampler): + """This sampler is built to work with a standard Pytorch ConcatDataset. + + It is used to retrieve elements from the different concatenated datasets placing them in the same batch + with proportion specified by batch_sizes, e.g 8, 16 means each batch will + be of 24 elements with the first 8 belonging to the first dataset in ConcatDataset + object and the last 16 to the second. + More than two datasets are supported, in that case you need to provide 3 batch + sizes. + + Note + ---- + Batched are drawn from the datasets till the one with smallest length is exhausted. + Thus number of examples in your training epoch is dictated by the dataset + whose length is the smallest. + + + Arguments + --------- + samplers : list or tuple + a list or tuple of pytorch samplers + batch_sizes: list + Batch sizes. + epoch : int + The epoch to start at. + + Example + ------- + >>> import torch + >>> from speechbrain.dataio.sampler import ( + ... ConcatDatasetBatchSampler, + ... ReproducibleRandomSampler, + ... ) + >>> from speechbrain.dataio.sampler import ReproducibleRandomSampler + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> # example "datasets" + >>> dataset1 = torch.arange(0, 10).unsqueeze(1) + >>> dataset2 = torch.arange(20, 40).unsqueeze(1) + >>> tot_dataset = torch.utils.data.ConcatDataset([dataset1, dataset2]) + >>> sampler1 = ReproducibleRandomSampler(dataset1) + >>> sampler2 = ReproducibleRandomSampler(dataset2) + >>> tot_sampler = ConcatDatasetBatchSampler([sampler1, sampler2], [2, 4]) + >>> dataloader = SaveableDataLoader( + ... tot_dataset, batch_sampler=tot_sampler, num_workers=3 + ... ) + >>> for data_point in dataloader: + ... assert len(data_point) == 6 + ... for i in range(2): + ... assert data_point[i] in [x for x in range(0, 10)] + ... for i in range(2, 4): + ... assert data_point[i] in [x for x in range(10, 40)] + """ + + def __init__( + self, samplers, batch_sizes: Union[tuple, list], epoch=0 + ) -> None: + if not isinstance(samplers, (list, tuple)): + raise ValueError( + "samplers should be a list or tuple of Pytorch Samplers, " + f"but got samplers={samplers}" + ) + + if not isinstance(batch_sizes, (list, tuple)): + raise ValueError( + "batch_sizes should be a list or tuple of integers, " + f"but got batch_sizes={batch_sizes}" + ) + + if not len(batch_sizes) == len(samplers): + raise ValueError( + "batch_sizes and samplers should be have same length" + ) + + self.batch_sizes = batch_sizes + self.samplers = samplers + self.offsets = [0] + np.cumsum( + [len(x) for x in self.samplers] + ).tolist()[:-1] + + self.epoch = epoch + self.set_epoch(self.epoch) + + def _iter_one_dataset(self, c_batch_size, c_sampler, c_offset): + batch = [] + for idx in c_sampler: + batch.append(c_offset + idx) + if len(batch) == c_batch_size: + yield batch + + def set_epoch(self, epoch): + """You can also just access self.epoch, but we maintain this interface + to mirror ``torch.utils.data.distributed.DistributedSampler``. + """ + if hasattr(self.samplers[0], "epoch"): + for s in self.samplers: + s.set_epoch(epoch) + + def __iter__(self): + iterators = [iter(i) for i in self.samplers] + tot_batch = [] + + for b_num in range(len(self)): + for samp_idx in range(len(self.samplers)): + c_batch = [] + while len(c_batch) < self.batch_sizes[samp_idx]: + c_batch.append( + self.offsets[samp_idx] + next(iterators[samp_idx]) + ) + tot_batch.extend(c_batch) + yield tot_batch + tot_batch = [] + + def __len__(self) -> int: + min_len = float("inf") + for idx, sampler in enumerate(self.samplers): + c_len = len(sampler) // self.batch_sizes[idx] + min_len = min(c_len, min_len) + + return int(min_len) + + +class DynamicBatchSampler(Sampler): + """This BatchSampler batches examples together by grouping them by their length. + + Every example in the batch have approximately the same length and + thus padding is minimized. + This enables faster training on datasets + where length of examples can vary significantly (e.g Librispeech). + Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length + + Dynamic batching is performed by specifying a max_batch_length which is the + upper limit for the sum of the length of examples in a batch: + e.g., if ex1 has length 4, ex2 length 5 and if max_batch_length is set to 6 + ex1 and ex2 will be placed, alone, in two distinct batches. + + Length for each example can be obtained in two manners. + If the input dataset is a DynamicItemDataset it can be obtained by specifying a + length_func. Default assumes a "duration" entry is in the annotation. + Length for each example can also be passed to this class upon instantiation + by specifying a list containing the length for each example and passing it to + lengths_list. + + Examples are grouped together by defining a set of possible discrete intervals + (buckets). Examples whose length fall into these intervals can be batched together. + + The number of buckets can be specified by using the arg num_buckets. + There is usually an optimal range for the value of this argument. + + If num_buckets == 1, all examples can be batched together. You have maximum randomization + but your training speed will be slower due to the fact that a large amount of the values will be padding + as long and short examples can be batched together. + As the number of buckets grows only examples with similar + length can be grouped together. + This trades-off speed with randomization. + TLDR: Low number -> better randomization, High number -> faster training. + NOTE THAT: if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size + will be small impacting training speed and possibly performance. + + The buckets can also be specified by passing a list to the bucket_boundaries + argument instead of specifying a left_bucket_length and a bucket_length_multiplier. + + Example + ------- + >>> import torch + >>> import speechbrain as sb + >>> from speechbrain.dataio.sampler import DynamicBatchSampler + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> from speechbrain.dataio.dataloader import SaveableDataLoader + >>> from speechbrain.dataio.batch import PaddedBatch + >>> import numpy as np + >>> item_lengths = sorted([np.random.randint(10, 100) for x in range(20)]) + >>> dataset = { + ... "ex_{}".format(x): {"wav": torch.randn(x)} for x in item_lengths + ... } + >>> dataset = DynamicItemDataset(dataset) + >>> dataset.set_output_keys(["wav"]) + >>> length_func = lambda x: len(x) # trivial in this example + >>> bsampler = DynamicBatchSampler( + ... dataset, + ... 20, + ... 4, + ... length_func, + ... shuffle=False, + ... batch_ordering="descending", + ... ) + >>> dataloader = SaveableDataLoader( + ... dataset, batch_sampler=bsampler, collate_fn=PaddedBatch + ... ) + >>> for i, b in enumerate(dataloader): + ... data, length = b["wav"] + >>> assert data.shape[-1] == max(item_lengths) + + Arguments + --------- + dataset : torch.utils.data.Dataset + Pytorch Dataset from which elements will be sampled. + max_batch_length : int + Upper limit for the sum of the length of examples in a batch. + Should be chosen based on your GPU memory. + num_buckets : int + Number of discrete buckets used to group examples together. + If num_buckets == 1, all examples can be batched together. As the number of buckets grows only examples with similar + length can be grouped together. This trades-off speed with randomization. + Low number -> better randomization, High number -> faster training. + However if set too high the training speed will decrease. If num_buckets -> number of examples in the dataset the batch size + will be small impacting training speed and possibly performance. + NOTE: you have either to specify manually the bucket_boundaries or the number of buckets. + length_func : callable + Function used to get length of each example from the dataset. + This argument can be used only when the dataset is a Speechbrain DynamicItemDataset object. + Can be anything: e.g. lambda x: x["duration"]*16000 returns number of samples + if duration key in the annotation is in seconds and the file has 16kHz sampling freq. + shuffle : bool + Whether or not shuffle examples between each epoch. + batch_ordering : string + If ``random``, batches are randomly permuted; otherwise ``ascending`` or ``descending`` sorted by length. + max_batch_ex: int + If set, it limits the maximum number of examples that can be in a batch superseding max_batch_length + in instances where the amount of examples will exceed the value specified here. + E.g. you have a lot of short examples and the batch size for those will be too high, you can use this argument + to limit the batch size for these short examples. + bucket_boundaries : list + Overrides bucket_length_multiplier and left_bucket_length by specifying manually + the buckets right boundaries. + lengths_list: list + Overrides length_func by passing a list containing the length of each example + in the dataset. This argument must be set when the dataset is a plain + Pytorch Dataset object and not a DynamicItemDataset object as length_func + cannot be used on Pytorch Datasets. + seed : int + Random seed. + epoch : int + The epoch to start at. + drop_last : bool + If ``True``, the sampler will drop the last examples which + have not been grouped. + verbose: bool + If ``True``, log also the stats for each batch at the first epoch. + """ + + def __init__( + self, + dataset, + max_batch_length: int, + num_buckets: Optional[int] = None, + length_func=lambda x: x["duration"], + shuffle: bool = True, + batch_ordering: str = "random", + max_batch_ex: Optional[int] = None, + bucket_boundaries: List[int] = [], + lengths_list: Optional[list[int]] = None, + seed: int = 42, + epoch: int = 0, + drop_last: bool = False, + verbose: bool = False, + ): + self._dataset = dataset + self._ex_lengths = {} + self.verbose = verbose + + # We do not put a default on num_buckets to encourage users to play with this parameter + if num_buckets is None and len(bucket_boundaries) == 0: + raise RuntimeError( + "Please specify either num_buckets or bucket boundaries." + "Check the docs, and/or the tutorial !" + ) + + if lengths_list is not None: + # take length of examples from this argument and bypass length_key + for indx in range(len(lengths_list)): + self._ex_lengths[str(indx)] = lengths_list[indx] + else: + # use length func + if not isinstance(dataset, DynamicItemDataset): + raise NotImplementedError( + "Dataset should be a Speechbrain DynamicItemDataset when using length function" + ) + for indx in range(len(self._dataset)): + self._ex_lengths[str(indx)] = length_func( + self._dataset.data[self._dataset.data_ids[indx]] + ) + + if len(bucket_boundaries) > 0: + if not all([x >= 0 for x in bucket_boundaries]): + raise ValueError( + "All elements in bucket boundaries should be non-negative (>= 0)." + ) + if not len(set(bucket_boundaries)) == len(bucket_boundaries): + raise ValueError( + "Bucket_boundaries should not contain duplicates." + ) + np.testing.assert_array_equal( + np.array(bucket_boundaries), + np.array(sorted(bucket_boundaries)), + err_msg="The arg bucket_boundaries should be an ascending sorted list of non negative values values!", + ) + self._bucket_boundaries = np.array(sorted(bucket_boundaries)) + else: + # use num_buckets + self._bucket_boundaries = np.array( + self._get_boundaries_through_warping( + max_batch_length=max_batch_length, + num_quantiles=num_buckets, + ) + ) + + self._max_batch_length = max_batch_length + self._shuffle_ex = shuffle + self._batch_ordering = batch_ordering + self._seed = seed + self._drop_last = drop_last + if max_batch_ex is None: + max_batch_ex = np.inf + self._max_batch_ex = max_batch_ex + # Calculate bucket lengths - how often does one bucket boundary fit into max_batch_length? + self._bucket_lens = [ + min( + self._max_batch_ex, # tops max_duration_per_batch + max( + 1, # and at least 1 + int(self._max_batch_length / self._bucket_boundaries[i]), + ), + ) + for i in range(len(self._bucket_boundaries)) + ] + [1] + self._epoch = epoch + self._generate_batches() + + def get_durations(self, batch): + """Gets durations of the elements in the batch.""" + return [self._ex_lengths[str(idx)] for idx in batch] + + def _get_boundaries_through_warping( + self, + max_batch_length: int, + num_quantiles: int, + ) -> List[int]: + # NOTE: the following lines do not cover that there is only one example in the dataset + # warp frames (duration) distribution of train data + logger.info("Batch quantisation in latent space") + # linspace set-up + num_boundaries = num_quantiles + 1 + # create latent linearly equal spaced buckets + latent_boundaries = np.linspace( + 1 / num_boundaries, + num_quantiles / num_boundaries, + num_quantiles, + ) + # get quantiles using lognormal distribution + quantiles = lognorm.ppf(latent_boundaries, 1) + # scale up to to max_batch_length + bucket_boundaries = quantiles * max_batch_length / quantiles[-1] + # compute resulting bucket length multipliers + length_multipliers = [ + bucket_boundaries[x + 1] / bucket_boundaries[x] + for x in range(num_quantiles - 1) + ] + # logging + logger.debug( + "Latent bucket boundary - buckets: {} - length multipliers: {}".format( + list(map("{:.2f}".format, bucket_boundaries)), + list(map("{:.2f}".format, length_multipliers)), + ) + ) + return sorted(bucket_boundaries) + + def _permute_batches(self): + if self._batch_ordering == "random": + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self._seed + self._epoch) + sampler = torch.randperm(len(self._batches), generator=g).tolist() # type: ignore + tmp = [] + for idx in sampler: + tmp.append(self._batches[idx]) + self._batches = tmp + + elif self._batch_ordering == "ascending": + self._batches = sorted( + self._batches, + key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), + ) + elif self._batch_ordering == "descending": + self._batches = sorted( + self._batches, + key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), + reverse=True, + ) + else: + raise NotImplementedError + + def _generate_batches(self): + logger.info("DynamicBatchSampler: Generating dynamic batches") + if self._shuffle_ex: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self._seed + self._epoch) + sampler = torch.randperm(len(self._dataset), generator=g).tolist() # type: ignore + else: + # take examples as they are: e.g. they have been sorted + sampler = range(len(self._dataset)) # type: ignore + + self._batches = [] + bucket_batches = [[] for i in self._bucket_lens] + + stats_tracker = [ + {"min": np.inf, "max": -np.inf, "tot": 0, "n_ex": 0} + for i in self._bucket_lens + ] + + for idx in sampler: + # length of pre-sampled audio + item_len = self._ex_lengths[str(idx)] + # bucket to fill up most padding + bucket_id = np.searchsorted(self._bucket_boundaries, item_len) + # fill audio's duration into that bucket + bucket_batches[bucket_id].append(idx) + + stats_tracker[bucket_id]["min"] = min( + stats_tracker[bucket_id]["min"], item_len + ) + stats_tracker[bucket_id]["max"] = max( + stats_tracker[bucket_id]["max"], item_len + ) + stats_tracker[bucket_id]["tot"] += item_len + stats_tracker[bucket_id]["n_ex"] += 1 + # track #samples - why not duration/#frames; rounded up? + # keep track of durations, if necessary + + if ( + len(bucket_batches[bucket_id]) >= self._bucket_lens[bucket_id] + or len(bucket_batches[bucket_id]) >= self._max_batch_ex + ): + self._batches.append(bucket_batches[bucket_id]) + bucket_batches[bucket_id] = [] + # keep track of durations + + # Dump remaining batches + if not self._drop_last: + for batch in bucket_batches: + if batch: + self._batches.append(batch) + + self._permute_batches() # possibly reorder batches + + if self._epoch == 0: # only log at first epoch + # frames per batch & their padding remaining + boundaries = [0] + self._bucket_boundaries.tolist() + + for bucket_indx in range(len(self._bucket_boundaries)): + try: + num_batches = stats_tracker[bucket_indx]["tot"] // ( + self._max_batch_length + ) + pad_factor = ( + stats_tracker[bucket_indx]["max"] + - stats_tracker[bucket_indx]["min"] + ) / ( + stats_tracker[bucket_indx]["tot"] + / stats_tracker[bucket_indx]["n_ex"] + ) + except ZeroDivisionError: + num_batches = 0 + pad_factor = 0 + + logger.debug( + ( + "DynamicBatchSampler: Bucket {} with boundary {:.1f}-{:.1f} and " + + "batch_size {}: Num Examples {:.1f}, Num Full Batches {:.3f}, Pad Factor {:.3f}." + ).format( + bucket_indx, + boundaries[bucket_indx], + boundaries[bucket_indx + 1], + self._bucket_lens[bucket_indx], + stats_tracker[bucket_indx]["n_ex"], + num_batches, + pad_factor * 100, + ) + ) + + if self.verbose: + batch_stats = { + "tot_frames": [], + "tot_pad_frames": [], + "pad_%": [], + } + for batch in self._batches: + tot_frames = sum( + [self._ex_lengths[str(idx)] for idx in batch] + ) + batch_stats["tot_frames"].append(tot_frames) + max_frames = max( + [self._ex_lengths[str(idx)] for idx in batch] + ) + tot_pad = sum( + [ + max_frames - self._ex_lengths[str(idx)] + for idx in batch + ] + ) + batch_stats["tot_pad_frames"].append(tot_pad) + batch_stats["pad_%"].append(tot_pad / tot_frames * 100) + + padding_details = "Batch {} with {:.1f} frames with {} files - {:.1f} padding, {:.2f} (%) of total." + padding_details = "DynamicBatchSampler: " + padding_details + for i in range(len(self._batches)): + logger.debug( + padding_details.format( + i, + batch_stats["tot_frames"][i], + len(self._batches[i]), + batch_stats["tot_pad_frames"][i], + batch_stats["pad_%"][i], + ) + ) + + def __iter__(self): + for batch in self._batches: + yield batch + if self._shuffle_ex: # re-generate examples if ex_ordering == "random" + self._generate_batches() + if self._batch_ordering == "random": + # we randomly permute the batches only --> faster + self._permute_batches() + + def set_epoch(self, epoch): + """ + You can also just access self.epoch, but we maintain this interface + to mirror torch.utils.data.distributed.DistributedSampler + """ + self._epoch = epoch + self._generate_batches() + + def __len__(self): + return len(self._batches) + + +# Heavily inspired by Catalyst, which is under Apache 2.0 license. +# https://github.com/catalyst-team/catalyst/blob/51428d7756e62b9b8ee5379f38e9fd576eeb36e5/catalyst/data/sampler.py#L522 +class DistributedSamplerWrapper(DistributedSampler): + """This wrapper allows using any sampler (for example batch) with Distributed Data Parallel (DDP) + correctly. + + Passing blindly the sampler to each DDP process will cause to have access + within each process to all the data in the dataset instead of only a subset + of it which is unique to each process. This wrapper prevents this and + allows to use only a subset of the original data for each process. + + NOTE + ---- + This is is automatically applied to any sampler in the Brain class when DDP + training is used. + """ + + def __init__(self, sampler, *args, **kwargs): + # DistributedSampler only calls len() on dataset + # so a sampler is fine to pass there, as well. + super().__init__(dataset=sampler, *args, **kwargs) + self.sampler = sampler + + def __iter__(self): + # It is easiest to use a random access interface to the wrapped + # sampler's indices, so we just fetch all indices from the wrapped + # sampler + sampler_indices = list(self.sampler.__iter__()) + indices_of_indices = super().__iter__() + # Itemgetter fetches the wrapped sampler indices from the positions + # pointed to by DistributedSampler + return iter(itemgetter(*indices_of_indices)(sampler_indices)) + + def set_epoch(self, epoch): + """Pass set_epoch() through to DistributedSampler and the wrapper one""" + super().set_epoch(epoch) + if hasattr(self.sampler, "set_epoch"): + self.sampler.set_epoch(epoch) + + +class BalancingDataSampler(ReproducibleWeightedRandomSampler): + """A data sampler that takes a single key from the dataset and + ensures an approximately equal distribution by that key + + Arguments + --------- + dataset : DynamicItemDataset + the dataset form which samples will be drawn + key : str + the key from which samples will be taken + num_samples : int + Number of samples to draw + replacement : bool + To draw with replacement or not (within an epoch of num_samples). + seed : int + The base seed to use for the random number generator. It is recommended + to use a value which has a good mix of 0 and 1 bits. + epoch : int + The epoch to start at. + **kwargs : dict + Arguments to pass to parent class. + + Example + ------- + >>> from speechbrain.dataio.sampler import BalancingDataSampler + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> sample_data = { + ... 1: {"category": "A", "text": "This is a test"}, + ... 2: {"category": "A", "text": "This is a second test"}, + ... 3: {"category": "B", "text": "This is a third test"}, + ... } + >>> dataset = DynamicItemDataset(data=sample_data) + >>> sampler = BalancingDataSampler( + ... dataset=dataset, key="category", num_samples=10 + ... ) + >>> sampler.weights + tensor([0.5000, 0.5000, 1.0000], dtype=torch.float64) + >>> it = iter(sampler) + >>> [next(it) for _ in range(10)] + [2, 2, 1, 2, 2, 0, 1, 1, 1, 2] + """ + + def __init__( + self, + dataset, + key, + num_samples=None, + replacement=True, + seed=563375142, + epoch=0, + **kwargs, + ): + self.dataset = dataset + self.key = key + if not num_samples: + num_samples = len(dataset) + weights = self._compute_weights() + super().__init__( + weights, num_samples, replacement, seed, epoch, **kwargs + ) + + def _compute_weights(self): + with self.dataset.output_keys_as([self.key]): + class_ids = [item[self.key] for item in self.dataset] + class_counter = Counter(class_ids) + weights = 1 / torch.tensor( + [class_counter[class_id] for class_id in class_ids] + ) + return weights diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/wer.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/wer.py new file mode 100644 index 000000000..dea945615 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/dataio/wer.py @@ -0,0 +1,201 @@ +"""WER print functions. + +The functions here are used to print the computed statistics +with human-readable formatting. +They have a file argument, but you can also just use +contextlib.redirect_stdout, which may give a nicer syntax. + +Authors + * Aku Rouhe 2020 +""" + +import sys + +from speechbrain.utils import edit_distance + + +def print_wer_summary(wer_details, file=sys.stdout): + """Prints out WER summary details in human-readable format. + + This function essentially mirrors the Kaldi compute-wer output format. + + Arguments + --------- + wer_details : dict + Dict of wer summary details, + see ``speechbrain.utils.edit_distance.wer_summary`` + for format. + file : stream + Where to write. (default: sys.stdout) + """ + print( + "%WER {WER:.2f} [ {num_edits} / {num_scored_tokens}, {insertions} ins, {deletions} del, {substitutions} sub ]".format( # noqa + **wer_details + ), + file=file, + end="", + ) + print( + ( + " [PARTIAL]" + if wer_details["num_scored_sents"] < wer_details["num_ref_sents"] + else "" + ), + file=file, + ) + print( + "%SER {SER:.2f} [ {num_erroneous_sents} / {num_scored_sents} ]".format( + **wer_details + ), + file=file, + ) + print( + "Scored {num_scored_sents} sentences, {num_absent_sents} not present in hyp.".format( # noqa + **wer_details + ), + file=file, + ) + + +def print_alignments( + details_by_utterance, + file=sys.stdout, + empty_symbol="", + separator=" ; ", + print_header=True, + sample_separator=None, +): + """Print WER summary and alignments. + + Arguments + --------- + details_by_utterance : list + List of wer details by utterance, + see ``speechbrain.utils.edit_distance.wer_details_by_utterance`` + for format. Has to have alignments included. + file : stream + Where to write. (default: sys.stdout) + empty_symbol : str + Symbol to use when aligning to nothing. + separator : str + String that separates each token in the output. Note the spaces in the + default. + print_header: bool + Whether to print headers + sample_separator: str + A separator to put between samples (optional) + """ + if print_header: + _print_alignments_global_header( + file=file, empty_symbol=empty_symbol, separator=separator + ) + for dets in details_by_utterance: + if dets["scored"]: + if print_header: + _print_alignment_header(dets, file=file) + _print_alignment( + dets["alignment"], + dets["ref_tokens"], + dets["hyp_tokens"], + file=file, + empty_symbol=empty_symbol, + separator=separator, + ) + if sample_separator: + print(sample_separator, file=file) + + +# The following internal functions are used to +# print out more specific things +def _print_top_wer_utts(top_non_empty, top_empty, file=sys.stdout): + print("=" * 80, file=file) + print("UTTERANCES WITH HIGHEST WER", file=file) + if top_non_empty: + print( + "Non-empty hypotheses -- utterances for which output was produced:", + file=file, + ) + for dets in top_non_empty: + print("{key} %WER {WER:.2f}".format(**dets), file=file) + else: + print("No utterances which had produced output!", file=file) + if top_empty: + print( + "Empty hypotheses -- utterances for which no output was produced:", + file=file, + ) + for dets in top_empty: + print("{key} %WER {WER:.2f}".format(**dets), file=file) + else: + print("No utterances which had not produced output!", file=file) + + +def _print_top_wer_spks(spks_by_wer, file=sys.stdout): + print("=" * 80, file=file) + print("SPEAKERS WITH HIGHEST WER", file=file) + for dets in spks_by_wer: + print("{speaker} %WER {WER:.2f}".format(**dets), file=file) + + +def _print_alignment( + alignment, a, b, empty_symbol="", separator=" ; ", file=sys.stdout +): + # First, get equal length text for all: + a_padded = [] + b_padded = [] + ops_padded = [] + for op, i, j in alignment: # i indexes a, j indexes b + op_string = str(op) + a_string = str(a[i]) if i is not None else empty_symbol + b_string = str(b[j]) if j is not None else empty_symbol + # NOTE: the padding does not actually compute printed length, + # but hopefully we can assume that printed length is + # at most the str len + pad_length = max(len(op_string), len(a_string), len(b_string)) + a_padded.append(a_string.center(pad_length)) + b_padded.append(b_string.center(pad_length)) + ops_padded.append(op_string.center(pad_length)) + # Then print, in the order Ref, op, Hyp + print(separator.join(a_padded), file=file) + print(separator.join(ops_padded), file=file) + print(separator.join(b_padded), file=file) + + +def _print_alignments_global_header( + empty_symbol="", separator=" ; ", file=sys.stdout +): + print("=" * 80, file=file) + print("ALIGNMENTS", file=file) + print("", file=file) + print("Format:", file=file) + print(", WER DETAILS", file=file) + # Print the format with the actual + # print_alignment function, using artificial data: + a = ["reference", "on", "the", "first", "line"] + b = ["and", "hypothesis", "on", "the", "third"] + alignment = [ + (edit_distance.EDIT_SYMBOLS["ins"], None, 0), + (edit_distance.EDIT_SYMBOLS["sub"], 0, 1), + (edit_distance.EDIT_SYMBOLS["eq"], 1, 2), + (edit_distance.EDIT_SYMBOLS["eq"], 2, 3), + (edit_distance.EDIT_SYMBOLS["sub"], 3, 4), + (edit_distance.EDIT_SYMBOLS["del"], 4, None), + ] + _print_alignment( + alignment, + a, + b, + file=file, + empty_symbol=empty_symbol, + separator=separator, + ) + + +def _print_alignment_header(wer_details, file=sys.stdout): + print("=" * 80, file=file) + print( + "{key}, %WER {WER:.2f} [ {num_edits} / {num_ref_tokens}, {insertions} ins, {deletions} del, {substitutions} sub ]".format( # noqa + **wer_details + ), + file=file, + ) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/__init__.py new file mode 100644 index 000000000..87014efdb --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/__init__.py @@ -0,0 +1,6 @@ +"""Package containing the different decoders (ctc, beamsearch ...)""" + +from .ctc import * # noqa +from .scorer import * # noqa +from .seq2seq import * # noqa +from .transducer import * # noqa diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/ctc.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/ctc.py new file mode 100644 index 000000000..ecaf689cc --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/ctc.py @@ -0,0 +1,1905 @@ +"""Decoders and output normalization for CTC. + +Authors + * Mirco Ravanelli 2020 + * Aku Rouhe 2020 + * Sung-Lin Yeh 2020 + * Adel Moumen 2023, 2024 +""" + +import dataclasses +import heapq +import math +import warnings +from itertools import groupby +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from speechbrain.dataio.dataio import length_to_mask +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +class CTCPrefixScore: + """This class implements the CTC prefix score of Algorithm 2 in + reference: https://www.merl.com/publications/docs/TR2017-190.pdf. + Official implementation: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py + + Arguments + --------- + x : torch.Tensor + The encoder states. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + blank_index : int + The index of the blank token. + eos_index : int + The index of the end-of-sequence (eos) token. + ctc_window_size: int + Compute the ctc scores over the time frames using windowing based on attention peaks. + If 0, no windowing applied. + """ + + def __init__(self, x, enc_lens, blank_index, eos_index, ctc_window_size=0): + self.blank_index = blank_index + self.eos_index = eos_index + self.batch_size = x.size(0) + self.max_enc_len = x.size(1) + self.vocab_size = x.size(-1) + self.device = x.device + self.minus_inf = -1e20 + self.last_frame_index = enc_lens - 1 + self.ctc_window_size = ctc_window_size + self.prefix_length = -1 + + # mask frames > enc_lens + mask = 1 - length_to_mask(enc_lens) + mask = mask.unsqueeze(-1).expand(-1, -1, x.size(-1)).eq(1) + x.masked_fill_(mask, self.minus_inf) + x[:, :, 0] = x[:, :, 0].masked_fill_(mask[:, :, 0], 0) + + # dim=0: xnb, nonblank posteriors, dim=1: xb, blank posteriors + xnb = x.transpose(0, 1) + xb = ( + xnb[:, :, self.blank_index] + .unsqueeze(2) + .expand(-1, -1, self.vocab_size) + ) + + # (2, L, batch_size * beam_size, vocab_size) + self.x = torch.stack([xnb, xb]) + + # indices of batch. + self.batch_index = torch.arange(self.batch_size, device=self.device) + + @torch.no_grad() + def forward_step(self, inp_tokens, states, candidates=None, attn=None): + """This method if one step of forwarding operation + for the prefix ctc scorer. + + Arguments + --------- + inp_tokens : torch.Tensor + The last chars of prefix label sequences g, where h = g + c. + states : tuple + Previous ctc states. + candidates : torch.Tensor + (batch_size * beam_size, ctc_beam_size), The topk candidates for rescoring. + If given, performing partial ctc scoring. + attn : torch.Tensor + (batch_size * beam_size, max_enc_len), The attention weights. + + Returns + ------- + new_psi : torch.Tensor + (r, psi, scoring_table) : tuple + """ + + n_bh = inp_tokens.size(0) + beam_size = n_bh // self.batch_size + last_char = inp_tokens + self.prefix_length += 1 + self.num_candidates = ( + self.vocab_size if candidates is None else candidates.size(-1) + ) + if states is None: + # r_prev: (L, 2, batch_size * beam_size) + r_prev = torch.full( + (self.max_enc_len, 2, self.batch_size, beam_size), + self.minus_inf, + device=self.device, + ) + + # Accumulate blank posteriors at each step + r_prev[:, 1] = torch.cumsum( + self.x[0, :, :, self.blank_index], 0 + ).unsqueeze(2) + r_prev = r_prev.view(-1, 2, n_bh) + psi_prev = torch.full( + (n_bh, self.vocab_size), 0.0, device=self.device + ) + else: + r_prev, psi_prev = states + + # for partial search + if candidates is not None: + # The first index of each candidate. + cand_offset = self.batch_index * self.vocab_size + scoring_table = torch.full( + (n_bh, self.vocab_size), + -1, + dtype=torch.long, + device=self.device, + ) + # Assign indices of candidates to their positions in the table + col_index = torch.arange(n_bh, device=self.device).unsqueeze(1) + scoring_table[col_index, candidates] = torch.arange( + self.num_candidates, device=self.device + ) + # Select candidates indices for scoring + scoring_index = ( + candidates + + cand_offset.unsqueeze(1).repeat(1, beam_size).view(-1, 1) + ).view(-1) + x_inflate = torch.index_select( + self.x.view(2, -1, self.batch_size * self.vocab_size), + 2, + scoring_index, + ).view(2, -1, n_bh, self.num_candidates) + # for full search + else: + scoring_table = None + # Inflate x to (2, -1, batch_size * beam_size, num_candidates) + # It is used to compute forward probs in a batched way + x_inflate = ( + self.x.unsqueeze(3) + .repeat(1, 1, 1, beam_size, 1) + .view(2, -1, n_bh, self.num_candidates) + ) + + # Prepare forward probs + r = torch.full( + (self.max_enc_len, 2, n_bh, self.num_candidates), + self.minus_inf, + device=self.device, + ) + r.fill_(self.minus_inf) + + # (Alg.2-6) + if self.prefix_length == 0: + r[0, 0] = x_inflate[0, 0] + # (Alg.2-10): phi = prev_nonblank + prev_blank = r_t-1^nb(g) + r_t-1^b(g) + r_sum = torch.logsumexp(r_prev, 1) + phi = r_sum.unsqueeze(2).repeat(1, 1, self.num_candidates) + + # (Alg.2-10): if last token of prefix g in candidates, phi = prev_b + 0 + if candidates is not None: + for i in range(n_bh): + pos = scoring_table[i, last_char[i]] + if pos != -1: + phi[:, i, pos] = r_prev[:, 1, i] + else: + for i in range(n_bh): + phi[:, i, last_char[i]] = r_prev[:, 1, i] + + # Start, end frames for scoring (|g| < |h|). + # Scoring based on attn peak if ctc_window_size > 0 + if self.ctc_window_size == 0 or attn is None: + start = max(1, self.prefix_length) + end = self.max_enc_len + else: + _, attn_peak = torch.max(attn, dim=1) + max_frame = torch.max(attn_peak).item() + self.ctc_window_size + min_frame = torch.min(attn_peak).item() - self.ctc_window_size + start = max(max(1, self.prefix_length), int(min_frame)) + end = min(self.max_enc_len, int(max_frame)) + + # Compute forward prob log(r_t^nb(h)) and log(r_t^b(h)): + for t in range(start, end): + # (Alg.2-11): dim=0, p(h|cur step is nonblank) = [p(prev step=y) + phi] * p(c) + rnb_prev = r[t - 1, 0] + # (Alg.2-12): dim=1, p(h|cur step is blank) = [p(prev step is blank) + p(prev step is nonblank)] * p(blank) + rb_prev = r[t - 1, 1] + r_ = torch.stack([rnb_prev, phi[t - 1], rnb_prev, rb_prev]).view( + 2, 2, n_bh, self.num_candidates + ) + r[t] = torch.logsumexp(r_, 1) + x_inflate[:, t] + + # Compute the predix prob, psi + psi_init = r[start - 1, 0].unsqueeze(0) + # phi is prob at t-1 step, shift one frame and add it to the current prob p(c) + phix = torch.cat((phi[0].unsqueeze(0), phi[:-1]), dim=0) + x_inflate[0] + # (Alg.2-13): psi = psi + phi * p(c) + if candidates is not None: + psi = torch.full( + (n_bh, self.vocab_size), self.minus_inf, device=self.device + ) + psi_ = torch.logsumexp( + torch.cat((phix[start:end], psi_init), dim=0), dim=0 + ) + # only assign prob to candidates + for i in range(n_bh): + psi[i, candidates[i]] = psi_[i] + else: + psi = torch.logsumexp( + torch.cat((phix[start:end], psi_init), dim=0), dim=0 + ) + + # (Alg.2-3): if c = , psi = log(r_T^n(g) + r_T^b(g)), where T is the length of max frames + for i in range(n_bh): + psi[i, self.eos_index] = r_sum[ + self.last_frame_index[i // beam_size], i + ] + + if self.eos_index != self.blank_index: + # Exclude blank probs for joint scoring + psi[:, self.blank_index] = self.minus_inf + + return psi - psi_prev, (r, psi, scoring_table) + + def permute_mem(self, memory, index): + """This method permutes the CTC model memory + to synchronize the memory index with the current output. + + Arguments + --------- + memory : No limit + The memory variable to be permuted. + index : torch.Tensor + The index of the previous path. + + Return + ------ + The variable of the memory being permuted. + + """ + + r, psi, scoring_table = memory + + beam_size = index.size(1) + n_bh = self.batch_size * beam_size + + # The first index of each batch. + beam_offset = self.batch_index * beam_size + # The index of top-K vocab came from in (t-1) timesteps at batch * beam * vocab dimension. + cand_index = ( + index + beam_offset.unsqueeze(1).expand_as(index) * self.vocab_size + ).view(n_bh) + # synchronize forward prob + psi = torch.index_select(psi.view(-1), dim=0, index=cand_index) + psi = ( + psi.view(-1, 1) + .repeat(1, self.vocab_size) + .view(n_bh, self.vocab_size) + ) + # The index of top-K vocab came from in (t-1) timesteps at batch * beam dimension. + hyp_index = ( + torch.div(index, self.vocab_size, rounding_mode="floor") + + beam_offset.unsqueeze(1).expand_as(index) + ).view(n_bh) + # synchronize ctc states + if scoring_table is not None: + selected_vocab = (index % self.vocab_size).view(-1) + score_index = scoring_table[hyp_index, selected_vocab] + score_index[score_index == -1] = 0 + cand_index = score_index + hyp_index * self.num_candidates + + r = torch.index_select( + r.view(-1, 2, n_bh * self.num_candidates), dim=-1, index=cand_index + ) + r = r.view(-1, 2, n_bh) + + return r, psi + + +def filter_ctc_output(string_pred, blank_id=-1): + """Apply CTC output merge and filter rules. + + Removes the blank symbol and output repetitions. + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the CTC system. + blank_id : int, string + The id of the blank. + + Returns + ------- + list + The output predicted by CTC without the blank symbol and + the repetitions. + + Example + ------- + >>> string_pred = ["a", "a", "blank", "b", "b", "blank", "c"] + >>> string_out = filter_ctc_output(string_pred, blank_id="blank") + >>> print(string_out) + ['a', 'b', 'c'] + """ + + if isinstance(string_pred, list): + # Filter the repetitions + string_out = [i[0] for i in groupby(string_pred)] + + # Filter the blank symbol + string_out = list(filter(lambda elem: elem != blank_id, string_out)) + else: + raise ValueError("filter_ctc_out can only filter python lists") + return string_out + + +def ctc_greedy_decode(probabilities, seq_lens, blank_id=-1): + """Greedy decode a batch of probabilities and apply CTC rules. + + Arguments + --------- + probabilities : torch.tensor + Output probabilities (or log-probabilities) from the network with shape + [batch, lengths, probabilities] + seq_lens : torch.tensor + Relative true sequence lengths (to deal with padded inputs), + the longest sequence has length 1.0, others a value between zero and one + shape [batch, lengths]. + blank_id : int, string + The blank symbol/index. Default: -1. If a negative number is given, + it is assumed to mean counting down from the maximum possible index, + so that -1 refers to the maximum possible index. + + Returns + ------- + list + Outputs as Python list of lists, with "ragged" dimensions; padding + has been removed. + + Example + ------- + >>> import torch + >>> probs = torch.tensor( + ... [[[0.3, 0.7], [0.0, 0.0]], [[0.2, 0.8], [0.9, 0.1]]] + ... ) + >>> lens = torch.tensor([0.51, 1.0]) + >>> blank_id = 0 + >>> ctc_greedy_decode(probs, lens, blank_id) + [[1], [1]] + """ + if isinstance(blank_id, int) and blank_id < 0: + blank_id = probabilities.shape[-1] + blank_id + batch_max_len = probabilities.shape[1] + batch_outputs = [] + for seq, seq_len in zip(probabilities, seq_lens): + actual_size = int(torch.round(seq_len * batch_max_len)) + scores, predictions = torch.max(seq.narrow(0, 0, actual_size), dim=1) + out = filter_ctc_output(predictions.tolist(), blank_id=blank_id) + batch_outputs.append(out) + return batch_outputs + + +@dataclasses.dataclass +class CTCBeam: + """This class handle the CTC beam information during decoding. + + Arguments + --------- + text : str + The current text of the beam. + full_text : str + The full text of the beam. + next_word : str + The next word to be added to the beam. + partial_word : str + The partial word being added to the beam. + last_token : str, optional + The last token of the beam. + last_token_index : int, optional + The index of the last token of the beam. + text_frames : List[Tuple[int, int]] + The start and end frame of the text. + partial_frames : Tuple[int, int] + The start and end frame of the partial word. + p : float + The probability of the beam. + p_b : float + The probability of the beam ending in a blank. + p_nb : float + The probability of the beam not ending in a blank. + n_p_b : float + The previous probability of the beam ending in a blank. + n_p_nb : float + The previous probability of the beam not ending in a blank. + score : float + The score of the beam (LM + CTC) + score_ctc : float + The CTC score computed. + + Example + ------- + >>> beam = CTCBeam( + ... text="", + ... full_text="", + ... next_word="", + ... partial_word="", + ... last_token=None, + ... last_token_index=None, + ... text_frames=[(0, 0)], + ... partial_frames=(0, 0), + ... p=-math.inf, + ... p_b=-math.inf, + ... p_nb=-math.inf, + ... n_p_b=-math.inf, + ... n_p_nb=-math.inf, + ... score=-math.inf, + ... score_ctc=-math.inf, + ... ) + """ + + text: str + full_text: str + next_word: str + partial_word: str + last_token: Optional[str] + last_token_index: Optional[int] + text_frames: List[Tuple[int, int]] + partial_frames: Tuple[int, int] + p: float = -math.inf + p_b: float = -math.inf + p_nb: float = -math.inf + n_p_b: float = -math.inf + n_p_nb: float = -math.inf + score: float = -math.inf + score_ctc: float = -math.inf + + @classmethod + def from_lm_beam(cls, lm_beam: "LMCTCBeam") -> "CTCBeam": + """Create a CTCBeam from a LMCTCBeam + + Arguments + --------- + lm_beam : LMCTCBeam + The LMCTCBeam to convert. + + Returns + ------- + CTCBeam + The CTCBeam converted. + """ + return CTCBeam( + text=lm_beam.text, + full_text=lm_beam.full_text, + next_word=lm_beam.next_word, + partial_word=lm_beam.partial_word, + last_token=lm_beam.last_token, + last_token_index=lm_beam.last_token_index, + text_frames=lm_beam.text_frames, + partial_frames=lm_beam.partial_frames, + p=lm_beam.p, + p_b=lm_beam.p_b, + p_nb=lm_beam.p_nb, + n_p_b=lm_beam.n_p_b, + n_p_nb=lm_beam.n_p_nb, + score=lm_beam.score, + score_ctc=lm_beam.score_ctc, + ) + + def step(self) -> None: + """Update the beam probabilities.""" + self.p_b, self.p_nb = self.n_p_b, self.n_p_nb + self.n_p_b = self.n_p_nb = -math.inf + self.score_ctc = np.logaddexp(self.p_b, self.p_nb) + self.score = self.score_ctc + + +@dataclasses.dataclass +class LMCTCBeam(CTCBeam): + """This class handle the LM scores during decoding. + + Arguments + --------- + lm_score: float + The LM score of the beam. + **kwargs + See CTCBeam for the other arguments. + """ + + lm_score: float = -math.inf + + +@dataclasses.dataclass +class CTCHypothesis: + """This class is a data handler over the generated hypotheses. + + This class is the default output of the CTC beam searchers. + + It can be re-used for other decoders if using + the beam searchers in an online fashion. + + Arguments + --------- + text : str + The text of the hypothesis. + last_lm_state : None + The last LM state of the hypothesis. + score : float + The score of the hypothesis. + lm_score : float + The LM score of the hypothesis. + text_frames : List[Tuple[str, Tuple[int, int]]], optional + The list of the text and the corresponding frames. + """ + + text: str + last_lm_state: None + score: float + lm_score: float + text_frames: Optional[list] = None + + +class CTCBaseSearcher(torch.nn.Module): + """CTCBaseSearcher class to be inherited by other + CTC beam searchers. + + This class provides the basic functionalities for + CTC beam search decoding. + + The space_token is required with a non-sentencepiece vocabulary list + if your transcription is expecting to contain spaces. + + Arguments + --------- + blank_index : int + The index of the blank token. + vocab_list : list + The list of the vocabulary tokens. + space_token : int, optional + The index of the space token. (default: -1) + kenlm_model_path : str, optional + The path to the kenlm model. Use .bin for a faster loading. + If None, no language model will be used. (default: None) + unigrams : list, optional + The list of known word unigrams. (default: None) + alpha : float + Weight for language model during shallow fusion. (default: 0.5) + beta : float + Weight for length score adjustment of during scoring. (default: 1.5) + unk_score_offset : float + Amount of log score offset for unknown tokens. (default: -10.0) + score_boundary : bool + Whether to have kenlm respect boundaries when scoring. (default: True) + beam_size : int, optional + The width of the beam. (default: 100) + beam_prune_logp : float, optional + The pruning threshold for the beam. (default: -10.0) + token_prune_min_logp : float, optional + The pruning threshold for the tokens. (default: -5.0) + prune_history : bool, optional + Whether to prune the history. (default: True) + Note: when using topk > 1, this should be set to False as + it is pruning a lot of beams. + blank_skip_threshold : float, optional + Skip frames if log_prob(blank) > log(blank_skip_threshold), to speed up decoding. + Note: This is only used when using the CUDA decoder, and it might worsen the WER/CER results. Use it at your own risk. (default: 1.0) + topk : int, optional + The number of top hypotheses to return. (default: 1) + spm_token: str, optional + The sentencepiece token. (default: "▁") + + Example + ------- + >>> blank_index = 0 + >>> vocab_list = ["blank", "a", "b", "c", " "] + >>> space_token = " " + >>> kenlm_model_path = None + >>> unigrams = None + >>> beam_size = 100 + >>> beam_prune_logp = -10.0 + >>> token_prune_min_logp = -5.0 + >>> prune_history = True + >>> blank_skip_threshold = 1.0 + >>> topk = 1 + >>> searcher = CTCBaseSearcher( + ... blank_index=blank_index, + ... vocab_list=vocab_list, + ... space_token=space_token, + ... kenlm_model_path=kenlm_model_path, + ... unigrams=unigrams, + ... beam_size=beam_size, + ... beam_prune_logp=beam_prune_logp, + ... token_prune_min_logp=token_prune_min_logp, + ... prune_history=prune_history, + ... blank_skip_threshold=blank_skip_threshold, + ... topk=topk, + ... ) + """ + + def __init__( + self, + blank_index: int, + vocab_list: List[str], + space_token: str = " ", + kenlm_model_path: Union[None, str] = None, + unigrams: Union[None, list[str], set[str]] = None, + alpha: float = 0.5, + beta: float = 1.5, + unk_score_offset: float = -10.0, + score_boundary: bool = True, + beam_size: int = 100, + beam_prune_logp: float = -10.0, + token_prune_min_logp: float = -5.0, + prune_history: bool = True, + blank_skip_threshold: float = 1.0, + topk: int = 1, + spm_token: str = "▁", + ): + super().__init__() + + self.blank_index = blank_index + self.vocab_list = vocab_list + self.space_token = space_token + self.kenlm_model_path = kenlm_model_path + self.unigrams = unigrams + self.alpha = alpha + self.beta = beta + self.unk_score_offset = unk_score_offset + self.score_boundary = score_boundary + self.beam_size = beam_size + self.beam_prune_logp = beam_prune_logp + self.token_prune_min_logp = token_prune_min_logp + self.prune_history = prune_history + self.blank_skip_threshold = math.log(blank_skip_threshold) + self.topk = topk + self.spm_token = spm_token + + # check if the vocab is coming from SentencePiece + self.is_spm = any( + [str(s).startswith(self.spm_token) for s in vocab_list] + ) + + # fetch the index of space_token + if not self.is_spm: + try: + self.space_index = vocab_list.index(space_token) + except ValueError: + logger.warning( + f"space_token `{space_token}` not found in the vocabulary." + "Using value -1 as `space_index`." + "Note: If your transcription is not expected to contain spaces, " + "you can ignore this warning." + ) + self.space_index = -1 + logger.info(f"Found `space_token` at index {self.space_index}.") + + self.kenlm_model = None + if kenlm_model_path is not None: + try: + import kenlm # type: ignore + + from speechbrain.integrations.decoders.kenlm_scorer import ( + KenlmScorer, + load_unigram_set_from_arpa, + ) + except ImportError: + raise ImportError( + "kenlm python bindings are not installed. To install it use: " + "pip install https://github.com/kpu/kenlm/archive/master.zip" + ) + + self.kenlm_model = kenlm.Model(kenlm_model_path) + + if kenlm_model_path is not None and kenlm_model_path.endswith(".arpa"): + logger.info( + "Using arpa instead of binary LM file, decoder instantiation might be slow." + ) + + if unigrams is None and kenlm_model_path is not None: + if kenlm_model_path.endswith(".arpa"): + unigrams = load_unigram_set_from_arpa(kenlm_model_path) + else: + logger.warning( + "Unigrams not provided and cannot be automatically determined from LM file (only " + "arpa format). Decoding accuracy might be reduced." + ) + + if self.kenlm_model is not None: + self.lm = KenlmScorer( + kenlm_model=self.kenlm_model, + unigrams=unigrams, + alpha=self.alpha, + beta=self.beta, + unk_score_offset=self.unk_score_offset, + score_boundary=self.score_boundary, + ) + else: + self.lm = None + + def partial_decoding( + self, + log_probs: torch.Tensor, + beams: List[CTCBeam], + cached_lm_scores: dict, + cached_p_lm_scores: dict, + processed_frames: int = 0, + ): + """Perform a single step of decoding. + + Arguments + --------- + log_probs : torch.Tensor + The log probabilities of the CTC output. + beams : list + The list of the beams. + cached_lm_scores : dict + The cached language model scores. + cached_p_lm_scores : dict + The cached prefix language model scores. + processed_frames : int, default: 0 + The start frame of the current decoding step. + """ + raise NotImplementedError + + def normalize_whitespace(self, text: str) -> str: + """Efficiently normalize whitespace. + + Arguments + --------- + text : str + The text to normalize. + + Returns + ------- + str + The normalized text. + """ + return " ".join(text.split()) + + def merge_tokens(self, token_1: str, token_2: str) -> str: + """Merge two tokens, and avoid empty ones. + + Taken from: https://github.com/kensho-technologies/pyctcdecode + + Arguments + --------- + token_1 : str + The first token. + token_2 : str + The second token. + + Returns + ------- + str + The merged token. + """ + if len(token_2) == 0: + text = token_1 + elif len(token_1) == 0: + text = token_2 + else: + text = token_1 + " " + token_2 + return text + + def merge_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]: + """Merge beams with the same text. + + Taken from: https://github.com/kensho-technologies/pyctcdecode + + Arguments + --------- + beams : list + The list of the beams. + + Returns + ------- + list + The list of CTCBeam merged. + """ + beam_dict = {} + for beam in beams: + new_text = self.merge_tokens(beam.text, beam.next_word) + hash_idx = (new_text, beam.partial_word, beam.last_token) + if hash_idx not in beam_dict: + beam_dict[hash_idx] = beam + else: + # We've already seen this text - we want to combine the scores + beam_dict[hash_idx] = dataclasses.replace( + beam, + score=np.logaddexp(beam_dict[hash_idx].score, beam.score), + ) + return list(beam_dict.values()) + + def sort_beams(self, beams: List[CTCBeam]) -> List[CTCBeam]: + """Sort beams by lm_score. + + Arguments + --------- + beams : list + The list of CTCBeam. + + Returns + ------- + list + The list of CTCBeam sorted. + """ + return heapq.nlargest(self.beam_size, beams, key=lambda x: x.lm_score) + + def _prune_history( + self, beams: List[CTCBeam], lm_order: int + ) -> List[CTCBeam]: + """Filter out beams that are the same over max_ngram history. + + Since n-gram language models have a finite history when scoring a new token, we can use that + fact to prune beams that only differ early on (more than n tokens in the past) and keep only the + higher scoring ones. Note that this helps speed up the decoding process but comes at the cost of + some amount of beam diversity. If more than the top beam is used in the output it should + potentially be disabled. + + Taken from: https://github.com/kensho-technologies/pyctcdecode + + Arguments + --------- + beams : list + The list of the beams. + lm_order : int + The order of the language model. + + Returns + ------- + list + The list of CTCBeam. + """ + # let's keep at least 1 word of history + min_n_history = max(1, lm_order - 1) + seen_hashes = set() + filtered_beams = [] + # for each beam after this, check if we need to add it + for lm_beam in beams: + # hash based on history that can still affect lm scoring going forward + hash_idx = ( + tuple(lm_beam.text.split()[-min_n_history:]), + lm_beam.partial_word, + lm_beam.last_token, + ) + if hash_idx not in seen_hashes: + filtered_beams.append(CTCBeam.from_lm_beam(lm_beam)) + seen_hashes.add(hash_idx) + return filtered_beams + + def finalize_decoding( + self, + beams: List[CTCBeam], + cached_lm_scores: dict, + cached_p_lm_scores: dict, + force_next_word=False, + is_end=False, + ) -> List[CTCBeam]: + """Finalize the decoding process by adding and scoring the last partial word. + + Arguments + --------- + beams : list + The list of CTCBeam. + cached_lm_scores : dict + The cached language model scores. + cached_p_lm_scores : dict + The cached prefix language model scores. + force_next_word : bool, default: False + Whether to force the next word. + is_end : bool, default: False + Whether the end of the sequence has been reached. + + Returns + ------- + list + The list of the CTCBeam. + """ + if force_next_word or is_end: + new_beams = [] + for beam in beams: + new_token_times = ( + beam.text_frames + if beam.partial_word == "" + else beam.text_frames + [beam.partial_frames] + ) + new_beams.append( + CTCBeam( + text=beam.text, + full_text=beam.full_text, + next_word=beam.partial_word, + partial_word="", + last_token=None, + last_token_index=None, + text_frames=new_token_times, + partial_frames=(-1, -1), + score=beam.score, + ) + ) + + new_beams = self.merge_beams(new_beams) + else: + new_beams = list(beams) + + scored_beams = self.get_lm_beams( + new_beams, cached_lm_scores, cached_p_lm_scores + ) + # remove beam outliers + max_score = max([b.lm_score for b in scored_beams]) + scored_beams = [ + b + for b in scored_beams + if b.lm_score >= max_score + self.beam_prune_logp + ] + + sorted_beams = self.sort_beams(scored_beams) + return sorted_beams + + def decode_beams( + self, + log_probs: torch.Tensor, + wav_lens: Optional[torch.Tensor] = None, + lm_start_state: Any = None, + ) -> List[List[CTCHypothesis]]: + """Decodes the input log probabilities of the CTC output. + + It automatically converts the SpeechBrain's relative length of the wav input + to the absolute length. + + Make sure that the input are in the log domain. The decoder will fail to decode + logits or probabilities. The input should be the log probabilities of the CTC output. + + Arguments + --------- + log_probs : torch.Tensor + The log probabilities of the CTC output. + The expected shape is [batch_size, seq_length, vocab_size]. + wav_lens : torch.Tensor, optional (default: None) + The SpeechBrain's relative length of the wav input. + lm_start_state : Any, optional (default: None) + The start state of the language model. + + Returns + ------- + list of list + The list of topk list of CTCHypothesis. + """ + # check that the last dimension of log_probs is equal to the vocab size + if log_probs.size(2) != len(self.vocab_list): + warnings.warn( + f"Vocab size mismatch: log_probs vocab dim is {log_probs.size(2)} " + f"while vocab_list is {len(self.vocab_list)}. " + "During decoding, going to truncate the log_probs vocab dim to match vocab_list." + ) + + # compute wav_lens and cast to numpy as it is faster + if wav_lens is not None: + wav_lens = log_probs.size(1) * wav_lens + wav_lens = wav_lens.cpu().numpy().astype(int) + else: + wav_lens = [log_probs.size(1)] * log_probs.size(0) + + log_probs = log_probs.cpu().numpy() + + hyps = [ + self.decode_log_probs(log_prob, wav_len, lm_start_state) + for log_prob, wav_len in zip(log_probs, wav_lens) + ] + return hyps + + def __call__( + self, + log_probs: torch.Tensor, + wav_lens: Optional[torch.Tensor] = None, + lm_start_state: Any = None, + ) -> List[List[CTCHypothesis]]: + """Decodes the log probabilities of the CTC output. + + It automatically converts the SpeechBrain's relative length of the wav input + to the absolute length. + + Each tensors is converted to numpy and CPU as it is faster and consumes less memory. + + Arguments + --------- + log_probs : torch.Tensor + The log probabilities of the CTC output. + The expected shape is [batch_size, seq_length, vocab_size]. + wav_lens : torch.Tensor, optional (default: None) + The SpeechBrain's relative length of the wav input. + lm_start_state : Any, optional (default: None) + The start state of the language model. + + Returns + ------- + list of list + The list of topk list of CTCHypothesis. + """ + return self.decode_beams(log_probs, wav_lens, lm_start_state) + + def partial_decode_beams( + self, + log_probs: torch.Tensor, + cached_lm_scores: dict, + cached_p_lm_scores: dict, + beams: List[CTCBeam], + processed_frames: int, + force_next_word=False, + is_end=False, + ) -> List[CTCBeam]: + """Perform a single step of decoding. + + Arguments + --------- + log_probs : torch.Tensor + The log probabilities of the CTC output. + cached_lm_scores : dict + The cached language model scores. + cached_p_lm_scores : dict + The cached prefix language model scores. + beams : list + The list of the beams. + processed_frames : int + The start frame of the current decoding step. + force_next_word : bool, optional (default: False) + Whether to force the next word. + is_end : bool, optional (default: False) + Whether the end of the sequence has been reached. + + Returns + ------- + list + The list of CTCBeam. + """ + beams = self.partial_decoding( + log_probs, + beams, + cached_lm_scores, + cached_p_lm_scores, + processed_frames=processed_frames, + ) + + trimmed_beams = self.finalize_decoding( + beams, + cached_lm_scores, + cached_p_lm_scores, + force_next_word=force_next_word, + is_end=is_end, + ) + + return trimmed_beams + + def decode_log_probs( + self, + log_probs: torch.Tensor, + wav_len: int, + lm_start_state: Optional[Any] = None, + ) -> List[CTCHypothesis]: + """Decodes the log probabilities of the CTC output. + + Arguments + --------- + log_probs : torch.Tensor + The log probabilities of the CTC output. + The expected shape is [seq_length, vocab_size]. + wav_len : int + The length of the wav input. + lm_start_state : Any, optional (default: None) + The start state of the language model. + + Returns + ------- + list + The topk list of CTCHypothesis. + """ + # prepare caching/state for language model + language_model = self.lm + if language_model is None: + cached_lm_scores = {} + else: + if lm_start_state is None: + start_state = language_model.get_start_state() + else: + start_state = lm_start_state + cached_lm_scores = {("", False): (0.0, start_state)} + cached_p_lm_scores: Dict[str, float] = {} + + beams = [ + CTCBeam( + text="", + full_text="", + next_word="", + partial_word="", + last_token=None, + last_token_index=None, + text_frames=[], + partial_frames=(-1, -1), + score=0.0, + score_ctc=0.0, + p_b=0.0, + ) + ] + + # loop over the frames and perform the decoding + beams = self.partial_decoding( + log_probs, wav_len, beams, cached_lm_scores, cached_p_lm_scores + ) + + # finalize decoding by adding and scoring the last partial word + trimmed_beams = self.finalize_decoding( + beams, + cached_lm_scores, + cached_p_lm_scores, + force_next_word=True, + is_end=True, + ) + + # transform the beams into hypotheses and select the topk + output_beams = [ + CTCHypothesis( + text=self.normalize_whitespace(lm_beam.text), + last_lm_state=( + cached_lm_scores[(lm_beam.text, True)][-1] + if (lm_beam.text, True) in cached_lm_scores + else None + ), + text_frames=list( + zip(lm_beam.text.split(), lm_beam.text_frames) + ), + score=lm_beam.score, + lm_score=lm_beam.lm_score, + ) + for lm_beam in trimmed_beams + ][: self.topk] + return output_beams + + +class CTCBeamSearcher(CTCBaseSearcher): + """CTC Beam Search is a Beam Search for CTC which does not keep track of + the blank and non-blank probabilities. Each new token probability is + added to the general score, and each beams that share the same text are + merged together. + + The implementation supports n-gram scoring on words and SentencePiece tokens. The input + is expected to be a log-probabilities tensor of shape [batch, time, vocab_size]. + + The main advantage of this CTCBeamSearcher over the CTCPrefixBeamSearcher is that it is + relatively faster, and obtains slightly better results. However, the implementation is + based on the one from the PyCTCDecode toolkit, adapted for the SpeechBrain's needs and does + not follow a specific paper. We do recommend to use the CTCPrefixBeamSearcher if you want + to cite the appropriate paper for the decoding method. + + Several heuristics are implemented to speed up the decoding process: + - pruning of the beam : the beams are pruned if their score is lower than + the best beam score minus the beam_prune_logp + - pruning of the tokens : the tokens are pruned if their score is lower than + the token_prune_min_logp + - pruning of the history : the beams are pruned if they are the same over + max_ngram history + - skipping of the blank : the frame is skipped if the blank probability is + higher than the blank_skip_threshold + + Note: if the Acoustic Model is not trained, the Beam Search will + take a lot of time. We do recommend to use Greedy Search during validation + until the model is fully trained and ready to be evaluated on test sets. + + Arguments + --------- + see CTCBaseSearcher, arguments are directly passed. + + Example + ------- + >>> import torch + >>> from speechbrain.decoders import CTCBeamSearcher + >>> probs = torch.tensor([[[0.2, 0.0, 0.8], [0.4, 0.0, 0.6]]]) + >>> log_probs = torch.log(probs) + >>> lens = torch.tensor([1.0]) + >>> blank_index = 2 + >>> vocab_list = ["a", "b", "-"] + >>> searcher = CTCBeamSearcher( + ... blank_index=blank_index, vocab_list=vocab_list + ... ) + >>> hyps = searcher(probs, lens) + """ + + def get_lm_beams( + self, + beams: List[CTCBeam], + cached_lm_scores: dict, + cached_partial_token_scores: dict, + is_eos=False, + ) -> List[LMCTCBeam]: + """Score the beams with the language model if not None, and + return the new beams. + + This function is modified and adapted from + https://github.com/kensho-technologies/pyctcdecode + + Arguments + --------- + beams : list + The list of the beams. + cached_lm_scores : dict + The cached language model scores. + cached_partial_token_scores : dict + The cached partial token scores. + is_eos : bool (default: False) + Whether the end of the sequence has been reached. + + Returns + ------- + new_beams : list + The list of the new beams. + """ + if self.lm is None: + # no lm is used, lm_score is equal to score and we can return the beams + new_beams = [] + for beam in beams: + new_text = self.merge_tokens(beam.text, beam.next_word) + new_beams.append( + LMCTCBeam( + text=new_text, + full_text=beam.full_text, + next_word="", + partial_word=beam.partial_word, + last_token=beam.last_token, + last_token_index=beam.last_token, + text_frames=beam.text_frames, + partial_frames=beam.partial_frames, + score=beam.score, + lm_score=beam.score, + ) + ) + return new_beams + else: + # lm is used, we need to compute the lm_score + # first we compute the lm_score of the next word + # we check if the next word is in the cache + # if not, we compute the score and add it to the cache + new_beams = [] + for beam in beams: + # fast token merge + new_text = self.merge_tokens(beam.text, beam.next_word) + cache_key = (new_text, is_eos) + if cache_key not in cached_lm_scores: + prev_raw_lm_score, start_state = cached_lm_scores[ + (beam.text, False) + ] + score, end_state = self.lm.score( + start_state, beam.next_word, is_last_word=is_eos + ) + raw_lm_score = prev_raw_lm_score + score + cached_lm_scores[cache_key] = (raw_lm_score, end_state) + lm_score, _ = cached_lm_scores[cache_key] + + # we score the partial word + word_part = beam.partial_word + if len(word_part) > 0: + if word_part not in cached_partial_token_scores: + cached_partial_token_scores[word_part] = ( + self.lm.score_partial_token(word_part) + ) + lm_score += cached_partial_token_scores[word_part] + + new_beams.append( + LMCTCBeam( + text=new_text, + full_text=beam.full_text, + next_word="", + partial_word=word_part, + last_token=beam.last_token, + last_token_index=beam.last_token, + text_frames=beam.text_frames, + partial_frames=beam.partial_frames, + score=beam.score, + lm_score=beam.score + lm_score, + ) + ) + return new_beams + + def partial_decoding( + self, + log_probs: torch.Tensor, + wav_len: int, + beams: List[CTCBeam], + cached_lm_scores: dict, + cached_p_lm_scores: dict, + processed_frames: int = 0, + ) -> List[CTCBeam]: + """Perform CTC Prefix Beam Search decoding. + + If self.lm is not None, the language model scores are computed and added to the CTC scores. + + Arguments + --------- + log_probs : torch.Tensor + The log probabilities of the CTC input. + Shape: (seq_length, vocab_size) + wav_len : int + The length of the input sequence. + beams : list + The list of CTCBeam objects. + cached_lm_scores : dict + The cached language model scores. + cached_p_lm_scores : dict + The cached prefix language model scores. + processed_frames : int + The start frame of the current decoding step. (default: 0) + + Returns + ------- + beams : list + The list of CTCBeam objects. + """ + # select only the valid frames i.e. the frames that are not padded + log_probs = log_probs[:wav_len] + + for frame_index, logit_col in enumerate( + log_probs, start=processed_frames + ): + # skip the frame if the blank probability is higher than the threshold + if logit_col[self.blank_index] > self.blank_skip_threshold: + continue + + # get the tokens with the highest probability + max_index = logit_col.argmax() + tokens_index_list = set( + np.where(logit_col > self.token_prune_min_logp)[0] + ) | {max_index} + new_beams = [] + + # select tokens that are in the vocab + # this is useful if the logit vocab_size is larger than the vocab_list + tokens_index_list = tokens_index_list & set( + range(len(self.vocab_list)) + ) + + for token_index in tokens_index_list: + p_token = logit_col[token_index] + token = self.vocab_list[token_index] + + for beam in beams: + if ( + token_index == self.blank_index + or beam.last_token == token + ): + if token_index == self.blank_index: + new_end_frame = beam.partial_frames[0] + else: + new_end_frame = frame_index + 1 + + new_part_frames = ( + beam.partial_frames + if token_index == self.blank_index + else (beam.partial_frames[0], new_end_frame) + ) + + # if blank or repeated token, we only change the score + new_beams.append( + CTCBeam( + text=beam.text, + full_text=beam.full_text, + next_word=beam.next_word, + partial_word=beam.partial_word, + last_token=token, + last_token_index=token_index, + text_frames=beam.text_frames, + partial_frames=new_part_frames, + score=beam.score + p_token, + ) + ) + + elif self.is_spm and token[:1] == self.spm_token: + # remove the spm token at the beginning of the token + clean_token = token[1:] + + new_frame_list = ( + beam.text_frames + if beam.partial_word == "" + else beam.text_frames + [beam.partial_frames] + ) + + # If the beginning of the token is the spm_token + # then it means that we are extending the beam with a new word. + # We need to change the new_word with the partial_word + # and reset the partial_word with the new token + new_beams.append( + CTCBeam( + text=beam.text, + full_text=beam.full_text, + next_word=beam.partial_word, + partial_word=clean_token, + last_token=token, + last_token_index=token_index, + text_frames=new_frame_list, + partial_frames=(frame_index, frame_index + 1), + score=beam.score + p_token, + ) + ) + + elif not self.is_spm and token_index == self.space_index: + new_frame_list = ( + beam.text_frames + if beam.partial_word == "" + else beam.text_frames + [beam.partial_frames] + ) + + # same as before but in the case of a non spm vocab + new_beams.append( + CTCBeam( + text=beam.text, + full_text=beam.full_text, + next_word=beam.partial_word, + partial_word="", + last_token=token, + last_token_index=token_index, + text_frames=new_frame_list, + partial_frames=(-1, -1), + score=beam.score + p_token, + ) + ) + else: + new_part_frames = ( + (frame_index, frame_index + 1) + if beam.partial_frames[0] < 0 + else (beam.partial_frames[0], frame_index + 1) + ) + + # last case, we are extending the partial_word with a new token + new_beams.append( + CTCBeam( + text=beam.text, + full_text=beam.full_text, + next_word=beam.next_word, + partial_word=beam.partial_word + token, + last_token=token, + last_token_index=token_index, + text_frames=beam.text_frames, + partial_frames=new_part_frames, + score=beam.score + p_token, + ) + ) + + # we merge the beams with the same text + new_beams = self.merge_beams(new_beams) + + # kenlm scoring + scored_beams = self.get_lm_beams( + new_beams, cached_lm_scores, cached_p_lm_scores + ) + + # remove beam outliers + max_score = max([b.lm_score for b in scored_beams]) + scored_beams = [ + b + for b in scored_beams + if b.lm_score >= max_score + self.beam_prune_logp + ] + + trimmed_beams = self.sort_beams(scored_beams) + + if self.prune_history: + lm_order = 1 if self.lm is None else self.lm.order + beams = self._prune_history(trimmed_beams, lm_order=lm_order) + else: + beams = [CTCBeam.from_lm_beam(b) for b in trimmed_beams] + + return beams + + +class CTCPrefixBeamSearcher(CTCBaseSearcher): + """CTC Prefix Beam Search is based on the paper + `First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs` + by Awni Y. Hannun and al (https://arxiv.org/abs/1408.2873). + + The implementation keep tracks of the blank and non-blank probabilities. + It also supports n-gram scoring on words and SentencePiece tokens. The input + is expected to be a log-probabilities tensor of shape [batch, time, vocab_size]. + + Several heuristics are implemented to speed up the decoding process: + - pruning of the beam : the beams are pruned if their score is lower than + the best beam score minus the beam_prune_logp + - pruning of the tokens : the tokens are pruned if their score is lower than + the token_prune_min_logp + - pruning of the history : the beams are pruned if they are the same over + max_ngram history + - skipping of the blank : the frame is skipped if the blank probability is + higher than the blank_skip_threshold + + Note: The CTCPrefixBeamSearcher can be more unstable than the CTCBeamSearcher + or the TorchAudioCTCPrefixBeamSearch searcher. Please, use it with caution + and check the results carefully. + + Note: if the Acoustic Model is not trained, the Beam Search will + take a lot of time. We do recommend to use Greedy Search during validation + until the model is fully trained and ready to be evaluated on test sets. + + Note: This implementation does not provide the time alignment of the + hypothesis. If you need it, please use the CTCBeamSearcher. + + Arguments + --------- + see CTCBaseSearcher, arguments are directly passed. + + Example + ------- + >>> import torch + >>> from speechbrain.decoders import CTCPrefixBeamSearcher + >>> probs = torch.tensor([[[0.2, 0.0, 0.8], [0.4, 0.0, 0.6]]]) + >>> log_probs = torch.log(probs) + >>> lens = torch.tensor([1.0]) + >>> blank_index = 2 + >>> vocab_list = ["a", "b", "-"] + >>> searcher = CTCPrefixBeamSearcher( + ... blank_index=blank_index, vocab_list=vocab_list + ... ) + >>> hyps = searcher(probs, lens) + """ + + def get_lm_beams( + self, + beams: List[CTCBeam], + cached_lm_scores: dict, + cached_partial_token_scores: dict, + is_eos=False, + ) -> List[LMCTCBeam]: + """Score the beams with the language model if not None, and + return the new beams. + + This function is modified and adapted from + https://github.com/kensho-technologies/pyctcdecode + + Arguments + --------- + beams : list + The list of the beams. + cached_lm_scores : dict + The cached language model scores. + cached_partial_token_scores : dict + The cached partial token scores. + is_eos : bool (default: False) + Whether the end of the sequence has been reached. + + Returns + ------- + new_beams : list + The list of the new beams. + """ + if self.lm is None: + # no lm is used, lm_score is equal to score and we can return the beams + # we have to keep track of the probabilities as well + new_beams = [] + for beam in beams: + new_text = self.merge_tokens(beam.full_text, beam.next_word) + new_beams.append( + LMCTCBeam( + text=beam.text, + full_text=new_text, + next_word="", + partial_word=beam.partial_word, + last_token=beam.last_token, + last_token_index=beam.last_token_index, + text_frames=beam.text_frames, + partial_frames=beam.partial_frames, + p=beam.p, + p_b=beam.p_b, + p_nb=beam.p_nb, + n_p_b=beam.n_p_b, + n_p_nb=beam.n_p_nb, + score=beam.score, + score_ctc=beam.score_ctc, + lm_score=beam.score, + ) + ) + return new_beams + else: + # lm is used, we need to compute the lm_score + # first we compute the lm_score of the next word + # we check if the next word is in the cache + # if not, we compute the score and add it to the cache + new_beams = [] + for beam in beams: + # fast token merge + new_text = self.merge_tokens(beam.full_text, beam.next_word) + cache_key = (new_text, is_eos) + if cache_key not in cached_lm_scores: + prev_raw_lm_score, start_state = cached_lm_scores[ + (beam.full_text, False) + ] + score, end_state = self.lm.score( + start_state, beam.next_word, is_last_word=is_eos + ) + raw_lm_score = prev_raw_lm_score + score + cached_lm_scores[cache_key] = (raw_lm_score, end_state) + lm_score, _ = cached_lm_scores[cache_key] + word_part = beam.partial_word + + # we score the partial word + if len(word_part) > 0: + if word_part not in cached_partial_token_scores: + cached_partial_token_scores[word_part] = ( + self.lm.score_partial_token(word_part) + ) + lm_score += cached_partial_token_scores[word_part] + + new_beams.append( + LMCTCBeam( + text=beam.text, + full_text=new_text, + next_word="", + partial_word=beam.partial_word, + last_token=beam.last_token, + last_token_index=beam.last_token_index, + text_frames=beam.text_frames, + partial_frames=beam.partial_frames, + p=beam.p, + p_b=beam.p_b, + p_nb=beam.p_nb, + n_p_b=beam.n_p_b, + n_p_nb=beam.n_p_nb, + score=beam.score, + score_ctc=beam.score_ctc, + lm_score=beam.score + lm_score, + ) + ) + return new_beams + + def _get_new_beam( + self, + frame_index: int, + new_prefix: str, + new_token: str, + new_token_index: int, + beams: List[CTCBeam], + p: float, + previous_beam: CTCBeam, + ) -> CTCBeam: + """Create a new beam and add it to the list of beams. + + Arguments + --------- + frame_index : int + The index of the current frame. + new_prefix : str + The new prefix. + new_token : str + The new token. + new_token_index : int + The index of the new token. + beams : list + The list of beams. + p : float + The probability of the new token. + previous_beam : CTCBeam + The previous beam. + + Returns + ------- + new_beam : CTCBeam + The new beam. + """ + for beam in beams: + if beam.text == new_prefix: + if p and p > beam.p: + beam.p = p + return beam + + if not self.is_spm and new_token_index == self.space_index: + new_frame_list = ( + previous_beam.text_frames + if previous_beam.partial_word == "" + else previous_beam.text_frames + [previous_beam.partial_frames] + ) + + # if we extend the beam with a space, we need to reset the partial word + # and move it to the next word + new_beam = CTCBeam( + text=new_prefix, + full_text=previous_beam.full_text, + next_word=previous_beam.partial_word, + partial_word="", + last_token=new_token, + last_token_index=new_token_index, + text_frames=new_frame_list, + partial_frames=(-1, -1), + score=-math.inf, + score_ctc=-math.inf, + p_b=-math.inf, + ) + elif self.is_spm and new_token[:1] == self.spm_token: + # remove the spm token at the beginning of the token + clean_token = new_token[1:] + + new_frame_list = ( + previous_beam.text_frames + if previous_beam.partial_word == "" + else previous_beam.text_frames + [previous_beam.partial_frames] + ) + + # If the beginning of the token is the spm_token + # then it means that we are extending the beam with a new word. + # We need to change the new_word with the partial_word + # and reset the partial_word with the new token + new_prefix = previous_beam.text + " " + clean_token + new_beam = CTCBeam( + text=new_prefix, + full_text=previous_beam.full_text, + next_word=previous_beam.partial_word, + partial_word=clean_token, + last_token=new_token, + last_token_index=new_token_index, + text_frames=new_frame_list, + partial_frames=(frame_index, frame_index + 1), + score=-math.inf, + score_ctc=-math.inf, + p_b=-math.inf, + ) + elif new_token_index == previous_beam.last_token_index: + new_end_frame = frame_index + 1 + + new_part_frames = ( + previous_beam.partial_frames + if new_token_index == self.blank_index + else (previous_beam.partial_frames[0], new_end_frame) + ) + + # if repeated token, we only change the score + new_beam = CTCBeam( + text=new_prefix, + full_text=previous_beam.full_text, + next_word="", + partial_word=previous_beam.partial_word, + last_token=new_token, + last_token_index=new_token_index, + text_frames=previous_beam.text_frames, + partial_frames=new_part_frames, + score=-math.inf, + score_ctc=-math.inf, + p_b=-math.inf, + ) + else: + new_part_frames = ( + (frame_index, frame_index + 1) + if previous_beam.partial_frames[0] < 0 + else (previous_beam.partial_frames[0], frame_index + 1) + ) + + # last case, we are extending the partial_word with a new token + new_beam = CTCBeam( + text=new_prefix, + full_text=previous_beam.full_text, + next_word="", + partial_word=previous_beam.partial_word + new_token, + last_token=new_token, + last_token_index=new_token_index, + text_frames=previous_beam.text_frames, + partial_frames=new_part_frames, + score=-math.inf, + score_ctc=-math.inf, + p_b=-math.inf, + ) + beams.append(new_beam) + if previous_beam: + new_beam.p = previous_beam.p + return new_beam + + def partial_decoding( + self, + log_probs: torch.Tensor, + wav_len: int, + beams: List[CTCBeam], + cached_lm_scores: dict, + cached_p_lm_scores: dict, + processed_frames: int = 0, + ) -> List[CTCBeam]: + """Perform CTC Prefix Beam Search decoding. + + If self.lm is not None, the language model scores are computed and added to the CTC scores. + + Arguments + --------- + log_probs : torch.Tensor + The log probabilities of the CTC input. + Shape: (seq_length, vocab_size) + wav_len : int + The length of the input sequence. + beams : list + The list of CTCBeam objects. + cached_lm_scores : dict + The cached language model scores. + cached_p_lm_scores : dict + The cached prefix language model scores. + processed_frames : int + The start frame of the current decoding step. (default: 0) + + Returns + ------- + beams : list + The list of CTCBeam objects. + """ + # select only the valid frames, i.e., the frames that are not padded + log_probs = log_probs[:wav_len] + + for frame_index, logit_col in enumerate( + log_probs, start=processed_frames + ): + # skip the frame if the blank probability is higher than the threshold + if logit_col[self.blank_index] > self.blank_skip_threshold: + continue + + # get the tokens with the highest probability + max_index = logit_col.argmax() + tokens_index_list = set( + np.where(logit_col > self.token_prune_min_logp)[0] + ) | {max_index} + + curr_beams = beams.copy() + + # select tokens that are in the vocab + # this is useful if the logit vocab_size is larger than the vocab_list + tokens_index_list = tokens_index_list & set( + range(len(self.vocab_list)) + ) + + for token_index in tokens_index_list: + p_token = logit_col[token_index] + token = self.vocab_list[token_index] + + for beam in curr_beams: + p_b, p_nb = beam.p_b, beam.p_nb + + # blank case + if token_index == self.blank_index: + beam.n_p_b = float( + np.logaddexp(beam.n_p_b, beam.score_ctc + p_token) + ) + continue + + if token == beam.last_token: + beam.n_p_nb = float( + np.logaddexp(beam.n_p_nb, p_nb + p_token) + ) + + new_text = beam.text + token + + new_beam = self._get_new_beam( + frame_index, + new_text, + token, + token_index, + beams, + p=p_token, + previous_beam=beam, + ) + + n_p_nb = new_beam.n_p_nb + + if token_index == beam.last_token_index and p_b > -math.inf: + n_p_nb = np.logaddexp(n_p_nb, p_b + p_token) + elif token_index != beam.last_token_index: + n_p_nb = np.logaddexp(n_p_nb, beam.score_ctc + p_token) + new_beam.n_p_nb = float(n_p_nb) + + # update the CTC probabilities + for beam in beams: + beam.step() + + # kenLM scores + scored_beams = self.get_lm_beams( + beams, cached_lm_scores, cached_p_lm_scores + ) + + # remove beams outliers + max_score = max([b.lm_score for b in scored_beams]) + scored_beams = [ + b + for b in scored_beams + if b.lm_score >= max_score + self.beam_prune_logp + ] + trimmed_beams = self.sort_beams(scored_beams) + + if self.prune_history: + lm_order = 1 if self.lm is None else self.lm.order + beams = self._prune_history(trimmed_beams, lm_order=lm_order) + else: + beams = [CTCBeam.from_lm_beam(b) for b in trimmed_beams] + + return beams diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/language_model.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/language_model.py new file mode 100644 index 000000000..9b186e1de --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/language_model.py @@ -0,0 +1,11 @@ +"""This file ensures old links to this file continue to work while providing a Deprecation warning""" + +import warnings + +from speechbrain.integrations.decoders.kenlm_scorer import * # noqa: F401, F403 + +warnings.warn( + message="speechbrain.decoders.language_model has moved to speechbrain.integrations.decoders.kenlm_scorer", + category=DeprecationWarning, + stacklevel=2, +) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/scorer.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/scorer.py new file mode 100644 index 000000000..c3b1a88e4 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/scorer.py @@ -0,0 +1,2189 @@ +""" +Token scorer abstraction and specifications. + +Authors: + * Adel Moumen 2022, 2023 + * Sung-Lin Yeh 2021 +""" + +import numpy as np +import torch + +import speechbrain as sb +from speechbrain.decoders.ctc import CTCPrefixScore + + +class BaseScorerInterface: + """A scorer abstraction to be inherited by other + scoring approaches for beam search. + + A scorer is a module that scores tokens in vocabulary + based on the current timestep input and the previous + scorer states. It can be used to score on full vocabulary + set (i.e., full scorers) or a pruned set of tokens (i.e. partial scorers) + to prevent computation overhead. In the latter case, the partial scorers + will be called after the full scorers. It will only scores the + top-k candidates (i.e., pruned set of tokens) extracted from the full scorers. + The top-k candidates are extracted based on the beam size and the + scorer_beam_scale such that the number of candidates is + int(beam_size * scorer_beam_scale). It can be very useful + when the full scorers are computationally expensive (e.g., KenLM scorer). + + Inherit this class to implement your own scorer compatible with + speechbrain.decoders.seq2seq.S2SBeamSearcher(). + + See: + - speechbrain.decoders.scorer.CTCPrefixScorer + - speechbrain.decoders.scorer.RNNLMScorer + - speechbrain.decoders.scorer.TransformerLMScorer + - speechbrain.decoders.scorer.KenLMScorer + - speechbrain.decoders.scorer.CoverageScorer + - speechbrain.decoders.scorer.LengthScorer + """ + + def score(self, inp_tokens, memory, candidates, attn): + """This method scores the new beams based on the + information of the current timestep. + + A score is a tensor of shape (batch_size x beam_size, vocab_size). + It is the log probability of the next token given the current + timestep input and the previous scorer states. + + It can be used to score on pruned top-k candidates + to prevent computation overhead, or on full vocabulary set + when candidates is None. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + memory : No limit + The scorer states for this timestep. + candidates : torch.Tensor + (batch_size x beam_size, scorer_beam_size). + The top-k candidates to be scored after the full scorers. + If None, scorers will score on full vocabulary set. + attn : torch.Tensor + The attention weight to be used in CoverageScorer or CTCScorer. + + Returns + ------- + torch.Tensor + (batch_size x beam_size, vocab_size), Scores for the next tokens. + memory : No limit + The memory variables input for this timestep. + """ + raise NotImplementedError + return + + def permute_mem(self, memory, index): + """This method permutes the scorer memory to synchronize + the memory index with the current output and perform + batched beam search. + + Arguments + --------- + memory : No limit + The memory variables input for this timestep. + index : torch.Tensor + (batch_size, beam_size). The index of the previous path. + """ + pass + + def reset_mem(self, x, enc_lens): + """This method should implement the resetting of + memory variables for the scorer. + + Arguments + --------- + x : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + enc_lens : torch.Tensor + The speechbrain-style relative length. + """ + pass + + +class CTCScorer(BaseScorerInterface): + """A wrapper of CTCPrefixScore based on the BaseScorerInterface. + + This Scorer is used to provides the CTC label-synchronous scores + of the next input tokens. The implementation is based on + https://www.merl.com/publications/docs/TR2017-190.pdf. + + See: + - speechbrain.decoders.scorer.CTCPrefixScore + + Arguments + --------- + ctc_fc : torch.nn.Module + A output linear layer for ctc. + blank_index : int + The index of the blank token. + eos_index : int + The index of the end-of-sequence (eos) token. + ctc_window_size : int + Compute the ctc scores over the time frames using windowing + based on attention peaks. If 0, no windowing applied. (default: 0) + + Example + ------- + >>> import torch + >>> from speechbrain.nnet.linear import Linear + >>> from speechbrain.lobes.models.transformer.TransformerASR import ( + ... TransformerASR, + ... ) + >>> from speechbrain.decoders import ( + ... S2STransformerBeamSearcher, + ... CTCScorer, + ... ScorerBuilder, + ... ) + >>> batch_size = 8 + >>> n_channels = 6 + >>> input_size = 40 + >>> d_model = 128 + >>> tgt_vocab = 140 + >>> src = torch.rand([batch_size, n_channels, input_size]) + >>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels]) + >>> net = TransformerASR( + ... tgt_vocab, + ... input_size, + ... d_model, + ... 8, + ... 1, + ... 1, + ... 1024, + ... activation=torch.nn.GELU, + ... ) + >>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) + >>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) + >>> eos_index = 2 + >>> ctc_scorer = CTCScorer( + ... ctc_fc=ctc_lin, + ... blank_index=0, + ... eos_index=eos_index, + ... ) + >>> scorer = ScorerBuilder(full_scorers=[ctc_scorer], weights={"ctc": 1.0}) + >>> searcher = S2STransformerBeamSearcher( + ... modules=[net, lin], + ... bos_index=1, + ... eos_index=eos_index, + ... min_decode_ratio=0.0, + ... max_decode_ratio=1.0, + ... using_eos_threshold=False, + ... beam_size=7, + ... temperature=1.15, + ... scorer=scorer, + ... ) + >>> enc, dec = net.forward(src, tgt) + >>> hyps, _, _, _ = searcher(enc, torch.ones(batch_size)) + """ + + def __init__(self, ctc_fc, blank_index, eos_index, ctc_window_size=0): + self.ctc_fc = ctc_fc + self.blank_index = blank_index + self.eos_index = eos_index + self.ctc_window_size = ctc_window_size + self.softmax = sb.nnet.activations.Softmax(apply_log=True) + + def score(self, inp_tokens, memory, candidates, attn): + """This method scores the new beams based on the + CTC scores computed over the time frames. + + See: + - speechbrain.decoders.scorer.CTCPrefixScore + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + memory : No limit + The scorer states for this timestep. + candidates : torch.Tensor + (batch_size x beam_size, scorer_beam_size). + The top-k candidates to be scored after the full scorers. + If None, scorers will score on full vocabulary set. + attn : torch.Tensor + The attention weight to be used in CoverageScorer or CTCScorer. + + Returns + ------- + scores : torch.Tensor + memory + """ + scores, memory = self.ctc_score.forward_step( + inp_tokens, memory, candidates, attn + ) + return scores, memory + + def permute_mem(self, memory, index): + """This method permutes the scorer memory to synchronize + the memory index with the current output and perform + batched CTC beam search. + + Arguments + --------- + memory : No limit + The memory variables input for this timestep. + index : torch.Tensor + (batch_size, beam_size). The index of the previous path. + + Returns + ------- + r, psi : see ``ctc_score.permute_mem`` + """ + r, psi = self.ctc_score.permute_mem(memory, index) + return r, psi + + def reset_mem(self, x, enc_lens): + """This method implement the resetting of + memory variables for the CTC scorer. + + Arguments + --------- + x : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + enc_lens : torch.Tensor + The speechbrain-style relative length. + """ + logits = self.ctc_fc(x) + x = self.softmax(logits) + self.ctc_score = CTCPrefixScore( + x, enc_lens, self.blank_index, self.eos_index, self.ctc_window_size + ) + + +class RNNLMScorer(BaseScorerInterface): + """A wrapper of RNNLM based on BaseScorerInterface. + + The RNNLMScorer is used to provide the RNNLM scores of the next input tokens + based on the current timestep input and the previous scorer states. + + Arguments + --------- + language_model : torch.nn.Module + A RNN-based language model. + temperature : float + Temperature factor applied to softmax. It changes the probability + distribution, being softer when T>1 and sharper with T<1. (default: 1.0) + + Example + ------- + >>> from speechbrain.nnet.linear import Linear + >>> from speechbrain.lobes.models.RNNLM import RNNLM + >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder + >>> from speechbrain.decoders import ( + ... S2SRNNBeamSearcher, + ... RNNLMScorer, + ... ScorerBuilder, + ... ) + >>> input_size = 17 + >>> vocab_size = 11 + >>> emb = torch.nn.Embedding( + ... embedding_dim=input_size, + ... num_embeddings=vocab_size, + ... ) + >>> d_model = 7 + >>> dec = AttentionalRNNDecoder( + ... rnn_type="gru", + ... attn_type="content", + ... hidden_size=3, + ... attn_dim=3, + ... num_layers=1, + ... enc_dim=d_model, + ... input_size=input_size, + ... ) + >>> n_channels = 3 + >>> seq_lin = Linear( + ... input_shape=[d_model, n_channels], n_neurons=vocab_size + ... ) + >>> lm_weight = 0.4 + >>> lm_model = RNNLM( + ... embedding_dim=d_model, + ... output_neurons=vocab_size, + ... dropout=0.0, + ... rnn_neurons=128, + ... dnn_neurons=64, + ... return_hidden=True, + ... ) + >>> rnnlm_scorer = RNNLMScorer( + ... language_model=lm_model, + ... temperature=1.25, + ... ) + >>> scorer = ScorerBuilder( + ... full_scorers=[rnnlm_scorer], weights={"rnnlm": lm_weight} + ... ) + >>> beam_size = 5 + >>> searcher = S2SRNNBeamSearcher( + ... embedding=emb, + ... decoder=dec, + ... linear=seq_lin, + ... bos_index=1, + ... eos_index=2, + ... min_decode_ratio=0.0, + ... max_decode_ratio=1.0, + ... topk=2, + ... using_eos_threshold=False, + ... beam_size=beam_size, + ... temperature=1.25, + ... scorer=scorer, + ... ) + >>> batch_size = 2 + >>> enc = torch.rand([batch_size, n_channels, d_model]) + >>> wav_len = torch.ones([batch_size]) + >>> hyps, _, _, _ = searcher(enc, wav_len) + """ + + def __init__(self, language_model, temperature=1.0): + self.lm = language_model + self.lm.eval() + self.temperature = temperature + self.softmax = sb.nnet.activations.Softmax(apply_log=True) + + def score(self, inp_tokens, memory, candidates, attn): + """This method scores the new beams based on the + RNNLM scores computed over the previous tokens. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + memory : No limit + The scorer states for this timestep. + candidates : torch.Tensor + (batch_size x beam_size, scorer_beam_size). + The top-k candidates to be scored after the full scorers. + If None, scorers will score on full vocabulary set. + attn : torch.Tensor + The attention weight to be used in CoverageScorer or CTCScorer. + + Returns + ------- + log_probs : torch.Tensor + Output probabilities. + hs : torch.Tensor + LM hidden states. + """ + with torch.no_grad(): + logits, hs = self.lm(inp_tokens, hx=memory) + log_probs = self.softmax(logits / self.temperature) + return log_probs, hs + + def permute_mem(self, memory, index): + """This method permutes the scorer memory to synchronize + the memory index with the current output and perform + batched beam search. + + Arguments + --------- + memory : No limit + The memory variables input for this timestep. + index : torch.Tensor + (batch_size, beam_size). The index of the previous path. + + Returns + ------- + memory + """ + if isinstance(memory, tuple): + memory_0 = torch.index_select(memory[0], dim=1, index=index) + memory_1 = torch.index_select(memory[1], dim=1, index=index) + memory = (memory_0, memory_1) + else: + memory = torch.index_select(memory, dim=1, index=index) + return memory + + def reset_mem(self, x, enc_lens): + """This method implement the resetting of + memory variables for the RNNLM scorer. + + Arguments + --------- + x : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + enc_lens : torch.Tensor + The speechbrain-style relative length. + """ + pass + + +class TransformerLMScorer(BaseScorerInterface): + """A wrapper of TransformerLM based on BaseScorerInterface. + + The TransformerLMScorer is used to provide the TransformerLM scores + of the next input tokens based on the current timestep input and the + previous scorer states. + + Arguments + --------- + language_model : torch.nn.Module + A Transformer-based language model. + temperature : float + Temperature factor applied to softmax. It changes the probability + distribution, being softer when T>1 and sharper with T<1. (default: 1.0) + + Example + ------- + >>> from speechbrain.nnet.linear import Linear + >>> from speechbrain.lobes.models.transformer.TransformerASR import ( + ... TransformerASR, + ... ) + >>> from speechbrain.lobes.models.transformer.TransformerLM import ( + ... TransformerLM, + ... ) + >>> from speechbrain.decoders import ( + ... S2STransformerBeamSearcher, + ... TransformerLMScorer, + ... CTCScorer, + ... ScorerBuilder, + ... ) + >>> input_size = 17 + >>> vocab_size = 11 + >>> d_model = 128 + >>> net = TransformerASR( + ... tgt_vocab=vocab_size, + ... input_size=input_size, + ... d_model=d_model, + ... nhead=8, + ... num_encoder_layers=1, + ... num_decoder_layers=1, + ... d_ffn=256, + ... activation=torch.nn.GELU, + ... ) + >>> lm_model = TransformerLM( + ... vocab=vocab_size, + ... d_model=d_model, + ... nhead=8, + ... num_encoder_layers=1, + ... num_decoder_layers=0, + ... d_ffn=256, + ... activation=torch.nn.GELU, + ... ) + >>> n_channels = 6 + >>> ctc_lin = Linear(input_size=d_model, n_neurons=vocab_size) + >>> seq_lin = Linear(input_size=d_model, n_neurons=vocab_size) + >>> eos_index = 2 + >>> ctc_scorer = CTCScorer( + ... ctc_fc=ctc_lin, + ... blank_index=0, + ... eos_index=eos_index, + ... ) + >>> transformerlm_scorer = TransformerLMScorer( + ... language_model=lm_model, + ... temperature=1.15, + ... ) + >>> ctc_weight_decode = 0.4 + >>> lm_weight = 0.6 + >>> scorer = ScorerBuilder( + ... full_scorers=[transformerlm_scorer, ctc_scorer], + ... weights={"transformerlm": lm_weight, "ctc": ctc_weight_decode}, + ... ) + >>> beam_size = 5 + >>> searcher = S2STransformerBeamSearcher( + ... modules=[net, seq_lin], + ... bos_index=1, + ... eos_index=eos_index, + ... min_decode_ratio=0.0, + ... max_decode_ratio=1.0, + ... using_eos_threshold=False, + ... beam_size=beam_size, + ... temperature=1.15, + ... scorer=scorer, + ... ) + >>> batch_size = 2 + >>> wav_len = torch.ones([batch_size]) + >>> src = torch.rand([batch_size, n_channels, input_size]) + >>> tgt = torch.randint(0, vocab_size, [batch_size, n_channels]) + >>> enc, dec = net.forward(src, tgt) + >>> hyps, _, _, _ = searcher(enc, wav_len) + """ + + def __init__(self, language_model, temperature=1.0): + self.lm = language_model + self.lm.eval() + self.temperature = temperature + self.softmax = sb.nnet.activations.Softmax(apply_log=True) + + def score(self, inp_tokens, memory, candidates, attn): + """This method scores the new beams based on the + TransformerLM scores computed over the previous tokens. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + memory : No limit + The scorer states for this timestep. + candidates : torch.Tensor + (batch_size x beam_size, scorer_beam_size). + The top-k candidates to be scored after the full scorers. + If None, scorers will score on full vocabulary set. + attn : torch.Tensor + The attention weight to be used in CoverageScorer or CTCScorer. + + Returns + ------- + log_probs : torch.Tensor + memory + """ + with torch.no_grad(): + if memory is None: + memory = torch.empty( + inp_tokens.size(0), 0, device=inp_tokens.device + ) + # Append the predicted token of the previous step to existing memory. + memory = torch.cat([memory, inp_tokens.unsqueeze(1)], dim=-1) + if not next(self.lm.parameters()).is_cuda: + self.lm.to(inp_tokens.device) + logits = self.lm(memory) + log_probs = self.softmax(logits / self.temperature) + return log_probs[:, -1, :], memory + + def permute_mem(self, memory, index): + """This method permutes the scorer memory to synchronize + the memory index with the current output and perform + batched beam search. + + Arguments + --------- + memory : No limit + The memory variables input for this timestep. + index : torch.Tensor + (batch_size, beam_size). The index of the previous path. + + Returns + ------- + memory + """ + memory = torch.index_select(memory, dim=0, index=index) + return memory + + def reset_mem(self, x, enc_lens): + """This method implement the resetting of + memory variables for the RNNLM scorer. + + Arguments + --------- + x : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + enc_lens : torch.Tensor + The speechbrain-style relative length. + """ + pass + + +class KenLMScorer(BaseScorerInterface): + """KenLM N-gram scorer. + + This scorer is based on KenLM, which is a fast and efficient + N-gram language model toolkit. It is used to provide the n-gram scores + of the next input tokens. + + This scorer is dependent on the KenLM package. It can be installed + with the following command: + > pip install https://github.com/kpu/kenlm/archive/master.zip + + Note: The KenLM scorer is computationally expensive. It is recommended + to use it as a partial scorer to score on the top-k candidates instead + of the full vocabulary set. + + Arguments + --------- + lm_path : str + The path of ngram model. + vocab_size: int + The total number of tokens. + token_list : list + The tokens set. + + Example + ------- + # >>> from speechbrain.nnet.linear import Linear + # >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder + # >>> from speechbrain.decoders import S2SRNNBeamSearcher, KenLMScorer, ScorerBuilder + # >>> input_size=17 + # >>> vocab_size=11 + # >>> lm_path='path/to/kenlm_model.arpa' # or .bin + # >>> token_list=['', '', '', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i'] + # >>> emb = torch.nn.Embedding( + # ... embedding_dim=input_size, + # ... num_embeddings=vocab_size, + # ... ) + # >>> d_model=7 + # >>> dec = AttentionalRNNDecoder( + # ... rnn_type="gru", + # ... attn_type="content", + # ... hidden_size=3, + # ... attn_dim=3, + # ... num_layers=1, + # ... enc_dim=d_model, + # ... input_size=input_size, + # ... ) + # >>> n_channels=3 + # >>> seq_lin = Linear(input_shape=[d_model, n_channels], n_neurons=vocab_size) + # >>> kenlm_weight = 0.4 + # >>> kenlm_model = KenLMScorer( + # ... lm_path=lm_path, + # ... vocab_size=vocab_size, + # ... token_list=token_list, + # ... ) + # >>> scorer = ScorerBuilder( + # ... full_scorers=[kenlm_model], + # ... weights={'kenlm': kenlm_weight} + # ... ) + # >>> beam_size=5 + # >>> searcher = S2SRNNBeamSearcher( + # ... embedding=emb, + # ... decoder=dec, + # ... linear=seq_lin, + # ... bos_index=1, + # ... eos_index=2, + # ... min_decode_ratio=0.0, + # ... max_decode_ratio=1.0, + # ... topk=2, + # ... using_eos_threshold=False, + # ... beam_size=beam_size, + # ... temperature=1.25, + # ... scorer=scorer + # ... ) + # >>> batch_size=2 + # >>> enc = torch.rand([batch_size, n_channels, d_model]) + # >>> wav_len = torch.ones([batch_size]) + # >>> hyps, _, _, _ = searcher(enc, wav_len) + """ + + def __init__(self, lm_path, vocab_size, token_list): + try: + import kenlm + + self.kenlm = kenlm + except ImportError: + MSG = """Couldn't import KenLM + It is an optional dependency; it is not installed with SpeechBrain + by default. Install it with: + > pip install https://github.com/kpu/kenlm/archive/master.zip + """ + raise ImportError(MSG) + self.lm = self.kenlm.Model(lm_path) + self.vocab_size = vocab_size + self.full_candidates = np.arange(self.vocab_size) + self.minus_inf = -1e20 + if len(token_list) != vocab_size: + MSG = "The size of the token_list and vocab_size are not matched." + raise ValueError(MSG) + self.id2char = token_list + + def score(self, inp_tokens, memory, candidates, attn): + """This method scores the new beams based on the + n-gram scores. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + memory : No limit + The scorer states for this timestep. + candidates : torch.Tensor + (batch_size x beam_size, scorer_beam_size). + The top-k candidates to be scored after the full scorers. + If None, scorers will score on full vocabulary set. + attn : torch.Tensor + The attention weight to be used in CoverageScorer or CTCScorer. + + Returns + ------- + scores : torch.Tensor + (new_memory, new_scoring_table) : tuple + """ + n_bh = inp_tokens.size(0) + scale = 1.0 / np.log10(np.e) + + if memory is None: + state = self.kenlm.State() + state = np.array([state] * n_bh) + scoring_table = np.ones(n_bh) + else: + state, scoring_table = memory + + # Perform full scorer mode, not recommend + if candidates is None: + candidates = [self.full_candidates] * n_bh + + # Store new states and scores + scores = np.ones((n_bh, self.vocab_size)) * self.minus_inf + new_memory = np.zeros((n_bh, self.vocab_size), dtype=object) + new_scoring_table = np.ones((n_bh, self.vocab_size)) * -1 + # Scoring + for i in range(n_bh): + if scoring_table[i] == -1: + continue + parent_state = state[i] + for token_id in candidates[i]: + char = self.id2char[token_id.item()] + out_state = self.kenlm.State() + score = scale * self.lm.BaseScore(parent_state, char, out_state) + scores[i, token_id] = score + new_memory[i, token_id] = out_state + new_scoring_table[i, token_id] = 1 + scores = torch.from_numpy(scores).float().to(inp_tokens.device) + return scores, (new_memory, new_scoring_table) + + def permute_mem(self, memory, index): + """This method permutes the scorer memory to synchronize + the memory index with the current output and perform + batched beam search. + + Arguments + --------- + memory : No limit + The memory variables input for this timestep. + index : torch.Tensor + (batch_size, beam_size). The index of the previous path. + + Returns + ------- + state : torch.Tensor + scoring_table : torch.Tensor + """ + state, scoring_table = memory + + index = index.cpu().numpy() + # The first index of each sentence. + beam_size = index.shape[1] + beam_offset = self.batch_index * beam_size + hyp_index = ( + index + + np.broadcast_to(np.expand_dims(beam_offset, 1), index.shape) + * self.vocab_size + ) + hyp_index = hyp_index.reshape(-1) + # Update states + state = state.reshape(-1) + state = state[hyp_index] + scoring_table = scoring_table.reshape(-1) + scoring_table = scoring_table[hyp_index] + return state, scoring_table + + def reset_mem(self, x, enc_lens): + """This method implement the resetting of + memory variables for the KenLM scorer. + + Arguments + --------- + x : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + enc_lens : torch.Tensor + The speechbrain-style relative length. + """ + state = self.kenlm.State() + self.lm.NullContextWrite(state) + self.batch_index = np.arange(x.size(0)) + + +class CoverageScorer(BaseScorerInterface): + """A coverage penalty scorer to prevent looping of hyps, + where ```coverage``` is the cumulative attention probability vector. + Reference: https://arxiv.org/pdf/1612.02695.pdf, + https://arxiv.org/pdf/1808.10792.pdf + + Arguments + --------- + vocab_size: int + The total number of tokens. + threshold: float + The penalty increases when the coverage of a frame is more + than given threshold. (default: 0.5) + + Example + ------- + >>> from speechbrain.nnet.linear import Linear + >>> from speechbrain.lobes.models.RNNLM import RNNLM + >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder + >>> from speechbrain.decoders import ( + ... S2SRNNBeamSearcher, + ... RNNLMScorer, + ... CoverageScorer, + ... ScorerBuilder, + ... ) + >>> input_size = 17 + >>> vocab_size = 11 + >>> emb = torch.nn.Embedding( + ... num_embeddings=vocab_size, embedding_dim=input_size + ... ) + >>> d_model = 7 + >>> dec = AttentionalRNNDecoder( + ... rnn_type="gru", + ... attn_type="content", + ... hidden_size=3, + ... attn_dim=3, + ... num_layers=1, + ... enc_dim=d_model, + ... input_size=input_size, + ... ) + >>> n_channels = 3 + >>> seq_lin = Linear( + ... input_shape=[d_model, n_channels], n_neurons=vocab_size + ... ) + >>> lm_weight = 0.4 + >>> coverage_penalty = 1.0 + >>> lm_model = RNNLM( + ... embedding_dim=d_model, + ... output_neurons=vocab_size, + ... dropout=0.0, + ... rnn_neurons=128, + ... dnn_neurons=64, + ... return_hidden=True, + ... ) + >>> rnnlm_scorer = RNNLMScorer( + ... language_model=lm_model, + ... temperature=1.25, + ... ) + >>> coverage_scorer = CoverageScorer(vocab_size=vocab_size) + >>> scorer = ScorerBuilder( + ... full_scorers=[rnnlm_scorer, coverage_scorer], + ... weights={"rnnlm": lm_weight, "coverage": coverage_penalty}, + ... ) + >>> beam_size = 5 + >>> searcher = S2SRNNBeamSearcher( + ... embedding=emb, + ... decoder=dec, + ... linear=seq_lin, + ... bos_index=1, + ... eos_index=2, + ... min_decode_ratio=0.0, + ... max_decode_ratio=1.0, + ... topk=2, + ... using_eos_threshold=False, + ... beam_size=beam_size, + ... temperature=1.25, + ... scorer=scorer, + ... ) + >>> batch_size = 2 + >>> enc = torch.rand([batch_size, n_channels, d_model]) + >>> wav_len = torch.ones([batch_size]) + >>> hyps, _, _, _ = searcher(enc, wav_len) + """ + + def __init__(self, vocab_size, threshold=0.5): + self.vocab_size = vocab_size + self.threshold = threshold + # Use time_step to normalize the coverage over steps + self.time_step = 0 + + def score(self, inp_tokens, coverage, candidates, attn): + """This method scores the new beams based on the + Coverage scorer. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + coverage : No limit + The scorer states for this timestep. + candidates : torch.Tensor + (batch_size x beam_size, scorer_beam_size). + The top-k candidates to be scored after the full scorers. + If None, scorers will score on full vocabulary set. + attn : torch.Tensor + The attention weight to be used in CoverageScorer or CTCScorer. + + Returns + ------- + score : torch.Tensor + coverage + """ + n_bh = attn.size(0) + self.time_step += 1 + + if coverage is None: + coverage = torch.zeros_like(attn, device=attn.device) + + # Current coverage + if len(attn.size()) > 2: + # the attn of transformer is [batch_size x beam_size, current_step, source_len] + coverage = torch.sum(attn, dim=1) + else: + coverage = coverage + attn + + # Compute coverage penalty and add it to scores + penalty = torch.max( + coverage, coverage.clone().fill_(self.threshold) + ).sum(-1) + penalty = penalty - coverage.size(-1) * self.threshold + penalty = penalty.view(n_bh).unsqueeze(1).expand(-1, self.vocab_size) + return -1 * penalty / self.time_step, coverage + + def permute_mem(self, coverage, index): + """This method permutes the scorer memory to synchronize + the memory index with the current output and perform + batched beam search. + + Arguments + --------- + coverage : No limit + The memory variables input for this timestep. + index : torch.Tensor + (batch_size, beam_size). The index of the previous path. + + Returns + ------- + coverage + """ + # Update coverage + coverage = torch.index_select(coverage, dim=0, index=index) + return coverage + + def reset_mem(self, x, enc_lens): + """This method implement the resetting of + memory variables for the RNNLM scorer. + + Arguments + --------- + x : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + enc_lens : torch.Tensor + The speechbrain-style relative length. + """ + self.time_step = 0 + + +class LengthScorer(BaseScorerInterface): + """A length rewarding scorer. + + The LengthScorer is used to provide the length rewarding scores. + It is used to prevent the beam search from favoring short hypotheses. + + Note: length_normalization is not compatible with this scorer. Make sure + to set is to False when using LengthScorer. + + Arguments + --------- + vocab_size: int + The total number of tokens. + + Example + ------- + >>> from speechbrain.nnet.linear import Linear + >>> from speechbrain.lobes.models.RNNLM import RNNLM + >>> from speechbrain.nnet.RNN import AttentionalRNNDecoder + >>> from speechbrain.decoders import ( + ... S2SRNNBeamSearcher, + ... RNNLMScorer, + ... CoverageScorer, + ... ScorerBuilder, + ... ) + >>> input_size = 17 + >>> vocab_size = 11 + >>> emb = torch.nn.Embedding( + ... num_embeddings=vocab_size, embedding_dim=input_size + ... ) + >>> d_model = 7 + >>> dec = AttentionalRNNDecoder( + ... rnn_type="gru", + ... attn_type="content", + ... hidden_size=3, + ... attn_dim=3, + ... num_layers=1, + ... enc_dim=d_model, + ... input_size=input_size, + ... ) + >>> n_channels = 3 + >>> seq_lin = Linear( + ... input_shape=[d_model, n_channels], n_neurons=vocab_size + ... ) + >>> lm_weight = 0.4 + >>> length_weight = 1.0 + >>> lm_model = RNNLM( + ... embedding_dim=d_model, + ... output_neurons=vocab_size, + ... dropout=0.0, + ... rnn_neurons=128, + ... dnn_neurons=64, + ... return_hidden=True, + ... ) + >>> rnnlm_scorer = RNNLMScorer( + ... language_model=lm_model, + ... temperature=1.25, + ... ) + >>> length_scorer = LengthScorer(vocab_size=vocab_size) + >>> scorer = ScorerBuilder( + ... full_scorers=[rnnlm_scorer, length_scorer], + ... weights={"rnnlm": lm_weight, "length": length_weight}, + ... ) + >>> beam_size = 5 + >>> searcher = S2SRNNBeamSearcher( + ... embedding=emb, + ... decoder=dec, + ... linear=seq_lin, + ... bos_index=1, + ... eos_index=2, + ... min_decode_ratio=0.0, + ... max_decode_ratio=1.0, + ... topk=2, + ... using_eos_threshold=False, + ... beam_size=beam_size, + ... temperature=1.25, + ... length_normalization=False, + ... scorer=scorer, + ... ) + >>> batch_size = 2 + >>> enc = torch.rand([batch_size, n_channels, d_model]) + >>> wav_len = torch.ones([batch_size]) + >>> hyps, _, _, _ = searcher(enc, wav_len) + """ + + def __init__(self, vocab_size): + self.vocab_size = vocab_size + + def score(self, inp_tokens, memory, candidates, attn): + """This method scores the new beams based on the + Length scorer. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + memory : No limit + The scorer states for this timestep. + candidates : torch.Tensor + (batch_size x beam_size, scorer_beam_size). + The top-k candidates to be scored after the full scorers. + If None, scorers will score on full vocabulary set. + attn : torch.Tensor + The attention weight to be used in CoverageScorer or CTCScorer. + + Returns + ------- + torch.Tensor + Scores + None + """ + return ( + torch.tensor( + [1.0], device=inp_tokens.device, dtype=inp_tokens.dtype + ).expand(inp_tokens.size(0), self.vocab_size), + None, + ) + + +class ScorerBuilder: + """Builds scorer instance for beamsearch. + + The ScorerBuilder class is responsible for building a scorer instance for + beam search. It takes weights for full and partial scorers, as well as + instances of full and partial scorer classes. It combines the scorers based + on the weights specified and provides methods for scoring tokens, permuting + scorer memory, and resetting scorer memory. + + This is the class to be used for building scorer instances for beam search. + + See speechbrain.decoders.seq2seq.S2SBeamSearcher() + + Arguments + --------- + weights : dict + Weights of full/partial scorers specified. + full_scorers : list + Scorers that score on full vocabulary set. + partial_scorers : list + Scorers that score on pruned tokens to prevent computation overhead. + Partial scoring is performed after full scorers. + scorer_beam_scale : float + The scale decides the number of pruned tokens for partial scorers: + int(beam_size * scorer_beam_scale). + + Example + ------- + >>> from speechbrain.nnet.linear import Linear + >>> from speechbrain.lobes.models.transformer.TransformerASR import ( + ... TransformerASR, + ... ) + >>> from speechbrain.lobes.models.transformer.TransformerLM import ( + ... TransformerLM, + ... ) + >>> from speechbrain.decoders import ( + ... S2STransformerBeamSearcher, + ... TransformerLMScorer, + ... CoverageScorer, + ... CTCScorer, + ... ScorerBuilder, + ... ) + >>> input_size = 17 + >>> vocab_size = 11 + >>> d_model = 128 + >>> net = TransformerASR( + ... tgt_vocab=vocab_size, + ... input_size=input_size, + ... d_model=d_model, + ... nhead=8, + ... num_encoder_layers=1, + ... num_decoder_layers=1, + ... d_ffn=256, + ... activation=torch.nn.GELU, + ... ) + >>> lm_model = TransformerLM( + ... vocab=vocab_size, + ... d_model=d_model, + ... nhead=8, + ... num_encoder_layers=1, + ... num_decoder_layers=0, + ... d_ffn=256, + ... activation=torch.nn.GELU, + ... ) + >>> n_channels = 6 + >>> ctc_lin = Linear(input_size=d_model, n_neurons=vocab_size) + >>> seq_lin = Linear(input_size=d_model, n_neurons=vocab_size) + >>> eos_index = 2 + >>> ctc_scorer = CTCScorer( + ... ctc_fc=ctc_lin, + ... blank_index=0, + ... eos_index=eos_index, + ... ) + >>> transformerlm_scorer = TransformerLMScorer( + ... language_model=lm_model, + ... temperature=1.15, + ... ) + >>> coverage_scorer = CoverageScorer(vocab_size=vocab_size) + >>> ctc_weight_decode = 0.4 + >>> lm_weight = 0.6 + >>> coverage_penalty = 1.0 + >>> scorer = ScorerBuilder( + ... full_scorers=[transformerlm_scorer, coverage_scorer], + ... partial_scorers=[ctc_scorer], + ... weights={ + ... "transformerlm": lm_weight, + ... "ctc": ctc_weight_decode, + ... "coverage": coverage_penalty, + ... }, + ... ) + >>> beam_size = 5 + >>> searcher = S2STransformerBeamSearcher( + ... modules=[net, seq_lin], + ... bos_index=1, + ... eos_index=eos_index, + ... min_decode_ratio=0.0, + ... max_decode_ratio=1.0, + ... using_eos_threshold=False, + ... beam_size=beam_size, + ... topk=3, + ... temperature=1.15, + ... scorer=scorer, + ... ) + >>> batch_size = 2 + >>> wav_len = torch.ones([batch_size]) + >>> src = torch.rand([batch_size, n_channels, input_size]) + >>> tgt = torch.randint(0, vocab_size, [batch_size, n_channels]) + >>> enc, dec = net.forward(src, tgt) + >>> hyps, _, _, _ = searcher(enc, wav_len) + """ + + def __init__( + self, + weights=dict(), + full_scorers=list(), + partial_scorers=list(), + scorer_beam_scale=2, + ): + assert len(weights) == len(full_scorers) + len(partial_scorers), ( + "Weights and scorers are not matched." + ) + + self.scorer_beam_scale = scorer_beam_scale + all_scorer_names = [ + k.lower().split("scorer")[0] + for k in globals().keys() + if k.endswith("Scorer") + ] + full_scorer_names = [ + impl.__class__.__name__.lower().split("scorer")[0] + for impl in full_scorers + ] + partial_scorer_names = [ + impl.__class__.__name__.lower().split("scorer")[0] + for impl in partial_scorers + ] + + # Have a default 0.0 weight for scorer not specified + init_weights = dict.fromkeys(all_scorer_names, 0.0) + self.weights = {**init_weights, **weights} + self.full_scorers = dict(zip(full_scorer_names, full_scorers)) + self.partial_scorers = dict(zip(partial_scorer_names, partial_scorers)) + + # Check if scorers are valid + self._validate_scorer(all_scorer_names) + + def score(self, inp_tokens, memory, attn, log_probs, beam_size): + """This method scores tokens in vocabulary based on defined full scorers + and partial scorers. Scores will be added to the log probs for beamsearch. + + Arguments + --------- + inp_tokens : torch.Tensor + See BaseScorerInterface(). + memory : dict[str, scorer memory] + The states of scorers for this timestep. + attn : torch.Tensor + See BaseScorerInterface(). + log_probs : torch.Tensor + (batch_size x beam_size, vocab_size). The log probs at this timestep. + beam_size : int + The beam size. + + Returns + ------- + log_probs : torch.Tensor + (batch_size x beam_size, vocab_size). Log probs updated by scorers. + new_memory : dict[str, scorer memory] + The updated states of scorers. + """ + new_memory = dict() + # score full candidates + for k, impl in self.full_scorers.items(): + if k == "ctc": + # block blank token if CTC is used + log_probs[:, impl.blank_index] = impl.ctc_score.minus_inf + + score, new_memory[k] = impl.score(inp_tokens, memory[k], None, attn) + log_probs += score * self.weights[k] + + # Select candidates from the results of full scorers for partial scorers + # clamp number of candidates to [1, vocab_size] to avoid invalid topk size + num_candidates = int(beam_size * self.scorer_beam_scale) + num_candidates = max(1, min(num_candidates, log_probs.shape[-1])) + candidates = log_probs.topk(num_candidates, dim=-1).indices + + # score pruned tokens candidates + for k, impl in self.partial_scorers.items(): + score, new_memory[k] = impl.score( + inp_tokens, memory[k], candidates, attn + ) + log_probs += score * self.weights[k] + + return log_probs, new_memory + + def permute_scorer_mem(self, memory, index, candidates): + """Update memory variables of scorers to synchronize + the memory index with the current output and perform + batched beam search. + + Arguments + --------- + memory : dict[str, scorer memory] + The states of scorers for this timestep. + index : torch.Tensor + (batch_size x beam_size). The index of the previous path. + candidates : torch.Tensor + (batch_size, beam_size). The index of the topk candidates. + + Returns + ------- + memory : dict + """ + for k, impl in self.full_scorers.items(): + # ctc scorer should always be scored by candidates + if k == "ctc" or k == "kenlm": + memory[k] = impl.permute_mem(memory[k], candidates) + continue + memory[k] = impl.permute_mem(memory[k], index) + for k, impl in self.partial_scorers.items(): + memory[k] = impl.permute_mem(memory[k], candidates) + return memory + + def reset_scorer_mem(self, x, enc_lens): + """Reset memory variables for scorers. + + Arguments + --------- + x : torch.Tensor + See BaseScorerInterface(). + enc_lens : torch.Tensor + See BaseScorerInterface(). + + Returns + ------- + memory : dict + """ + memory = dict() + for k, impl in {**self.full_scorers, **self.partial_scorers}.items(): + memory[k] = impl.reset_mem(x, enc_lens) + return memory + + def _validate_scorer(self, scorer_names): + """These error messages indicate scorers are not properly set. + + Arguments + --------- + scorer_names : list + Prefix of scorers defined in speechbrain.decoders.scorer. + """ + if len(self.weights) > len(scorer_names): + raise ValueError( + f"The keys of weights should be named in {scorer_names}" + ) + + if not 0.0 <= self.weights["ctc"] <= 1.0: + raise ValueError("ctc_weight should not > 1.0 and < 0.0") + + if self.weights["ctc"] == 1.0: + if "ctc" not in self.full_scorers.keys(): + raise ValueError( + "CTC scorer should be a full scorer when it's weight is 1.0" + ) + if self.weights["coverage"] > 0.0: + raise ValueError( + "Pure CTC scorer doesn't have attention weights for coverage scorer" + ) + + +class BaseRescorerInterface(BaseScorerInterface): + """A scorer abstraction intended for inheritance by other scoring approaches used in beam search. + + In this approach, a neural network is employed to assign scores to potential text transcripts. + The beam search decoding process produces a collection of the top K hypotheses. + These candidates are subsequently sent to a language model (LM) for ranking. + The ranking is carried out by the LM, which assigns a score to each candidate. + + The score is computed as follows: + + score = beam_search_score + lm_weight * rescorer_score + + See: + - speechbrain.decoders.scorer.RNNLMRescorer + - speechbrain.decoders.scorer.TransformerLMRescorer + - speechbrain.decoders.scorer.HuggingFaceLMRescorer + """ + + def normalize_text(self, text): + """This method should implement the normalization of the text before scoring. + + Arguments + --------- + text : list of str + The text to be normalized. + + Returns + ------- + Normalized text + """ + return text + + def preprocess_func(self, hyps): + """This method should implement the preprocessing of the hypotheses before scoring. + + Arguments + --------- + hyps : list of str + The hypotheses to be preprocessed. + """ + raise NotImplementedError + + def rescore_hyps(self, hyps): + """This method should implement the rescoring of the hypotheses. + + Arguments + --------- + hyps : list of str + The hypotheses to be rescored. + """ + raise NotImplementedError + + def to_device(self, device=None): + """This method should implement the moving of the scorer to a device. + + If device is None, the scorer should be moved to the default device provided + in the constructor. + + Arguments + --------- + device : str + The device to move the scorer to. + """ + raise NotImplementedError + + +class RNNLMRescorer(BaseRescorerInterface): + """A wrapper of RNNLM based on the BaseRescorerInterface. + + Arguments + --------- + language_model : torch.nn.Module + A RNN-based language model. + tokenizer : SentencePieceProcessor + A SentencePiece tokenizer. + device : str + The device to move the scorer to. + temperature : float + Temperature factor applied to softmax. It changes the probability + distribution, being softer when T>1 and sharper with T<1. (default: 1.0) + bos_index : int + The index of the beginning-of-sequence (bos) token. + eos_index : int + The index of the end-of-sequence (eos) token. + pad_index : int + The index of the padding token. + + Note + ---- + This class is intended to be used with a pretrained TransformerLM model. + Please see: https://huggingface.co/speechbrain/asr-crdnn-rnnlm-librispeech + + By default, this model is using SentencePiece tokenizer. + + Example + ------- + >>> import torch + >>> from sentencepiece import SentencePieceProcessor + >>> from speechbrain.lobes.models.RNNLM import RNNLM + >>> from speechbrain.utils.parameter_transfer import Pretrainer + >>> source = "speechbrain/asr-crdnn-rnnlm-librispeech" + >>> lm_model_path = source + "/lm.ckpt" + >>> tokenizer_path = source + "/tokenizer.ckpt" + >>> # define your tokenizer and RNNLM from the HF hub + >>> tokenizer = SentencePieceProcessor() + >>> lm_model = RNNLM( + ... output_neurons=1000, + ... embedding_dim=128, + ... activation=torch.nn.LeakyReLU, + ... dropout=0.0, + ... rnn_layers=2, + ... rnn_neurons=2048, + ... dnn_blocks=1, + ... dnn_neurons=512, + ... return_hidden=True, + ... ) + >>> pretrainer = Pretrainer( + ... collect_in=getfixture("tmp_path"), + ... loadables={ + ... "lm": lm_model, + ... "tokenizer": tokenizer, + ... }, + ... paths={ + ... "lm": lm_model_path, + ... "tokenizer": tokenizer_path, + ... }, + ... ) + >>> _ = pretrainer.collect_files() + >>> pretrainer.load_collected() + >>> from speechbrain.decoders.scorer import RNNLMRescorer, RescorerBuilder + >>> rnnlm_rescorer = RNNLMRescorer( + ... language_model=lm_model, + ... tokenizer=tokenizer, + ... temperature=1.0, + ... bos_index=0, + ... eos_index=0, + ... pad_index=0, + ... ) + >>> # Define a rescorer builder + >>> rescorer = RescorerBuilder( + ... rescorers=[rnnlm_rescorer], weights={"rnnlm": 1.0} + ... ) + >>> # topk hyps + >>> topk_hyps = [["HELLO", "HE LLO", "H E L L O"]] + >>> topk_scores = [[-2, -2, -2]] + >>> rescored_hyps, rescored_scores = rescorer.rescore( + ... topk_hyps, topk_scores + ... ) + >>> # NOTE: the returned hypotheses are already sorted by score. + >>> rescored_hyps # doctest: +SKIP + [['HELLO', 'H E L L O', 'HE LLO']] + >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better. + >>> rescored_scores # doctest: +SKIP + [[-17.863974571228027, -25.12890625, -26.075977325439453]] + """ + + def __init__( + self, + language_model, + tokenizer, + device="cuda", + temperature=1.0, + bos_index=0, + eos_index=0, + pad_index=0, + ): + self.lm = language_model + self.lm.eval() + self.tokenizer = tokenizer + self.temperature = temperature + self.softmax = sb.nnet.activations.Softmax(apply_log=True) + + self.device = device + self.bos_index = bos_index + self.eos_index = eos_index + self.pad_index = pad_index + + def normalize_text(self, text): + """This method should implement the normalization of the text before scoring. + + Default to uppercasing the text because the (current) language models are trained on + LibriSpeech which is all uppercase. + + Arguments + --------- + text : str + The text to be normalized. + + Returns + ------- + str + The normalized text. + """ + return text.upper() + + def to_device(self, device=None): + """This method moves the scorer to a device. + + If device is None, the scorer is moved to the default device provided + in the constructor. + + Arguments + --------- + device : str + The device to move the scorer to. + """ + if device is None: + self.lm.to(self.device) + else: + self.lm.to(device) + + def preprocess_func(self, topk_hyps): + """This method preprocesses the hypotheses before scoring. + + Arguments + --------- + topk_hyps : list of list of str + The hypotheses to be preprocessed. + + Returns + ------- + padded_hyps : torch.Tensor + The padded hypotheses. + enc_hyps_length : list of int + The length of each hypothesis. + """ + # 1. normalize text + decoded_seq = [] + for batch in topk_hyps: + for seq in batch: + decoded_seq.append(self.normalize_text(seq)) + + # 2. encode text + enc_hyps = [] + for seq in decoded_seq: + enc_hyps.append( + torch.tensor( + [self.bos_index] + + self.tokenizer.encode_as_ids(seq) + + [self.eos_index] + ) + ) + + enc_hyps_length = [enc_seq.shape[0] for enc_seq in enc_hyps] + + # 3. pad sequences + padded_hyps = torch.nn.utils.rnn.pad_sequence( + enc_hyps, batch_first=True, padding_value=self.pad_index + ).to(self.lm.parameters().__next__().device) + + return padded_hyps, enc_hyps_length + + @torch.no_grad() + def rescore_hyps(self, topk_hyps): + """This method implement the rescoring of the hypotheses. + + Arguments + --------- + topk_hyps : list of list of str + The hypotheses to be rescored. + + Returns + ------- + log_probs_scores : torch.Tensor[B * Topk, 1] + The rescored hypotheses scores + """ + # preprocess hypotheses + padded_hyps, enc_hyps_length = self.preprocess_func(topk_hyps) + + bool_mask = [ + [1 if i < length else 0 for i in range(max(enc_hyps_length))] + for length in enc_hyps_length + ] + + bool_mask_tensor = torch.tensor( + bool_mask, dtype=torch.bool, device=padded_hyps.device + ) + + if not next(self.lm.parameters()).is_cuda: + self.lm.to(padded_hyps.device) + + # compute scores + logits, _ = self.lm(padded_hyps) + log_probs = self.softmax(logits / self.temperature) + + target_log_probs = ( + log_probs[:, :-1] + .gather(2, padded_hyps[:, 1:].unsqueeze(2)) + .squeeze(2) + ) + + log_probs_scores = torch.nansum( + target_log_probs * bool_mask_tensor[:, 1:], dim=-1 + ) + + return log_probs_scores + + +class TransformerLMRescorer(BaseRescorerInterface): + """A wrapper of TransformerLM based on the BaseRescorerInterface. + + Arguments + --------- + language_model : torch.nn.Module + A Transformer-based language model. + tokenizer : SentencePieceProcessor + A SentencePiece tokenizer. + device : str + The device to move the scorer to. + temperature : float + Temperature factor applied to softmax. It changes the probability + distribution, being softer when T>1 and sharper with T<1. (default: 1.0) + bos_index : int + The index of the beginning-of-sequence (bos) token. + eos_index : int + The index of the end-of-sequence (eos) token. + pad_index : int + The index of the padding token. + + Note + ---- + This class is intended to be used with a pretrained TransformerLM model. + Please see: https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech + + By default, this model is using SentencePiece tokenizer. + + Example + ------- + >>> import torch + >>> from sentencepiece import SentencePieceProcessor + >>> from speechbrain.lobes.models.transformer.TransformerLM import ( + ... TransformerLM, + ... ) + >>> from speechbrain.utils.parameter_transfer import Pretrainer + >>> source = "speechbrain/asr-transformer-transformerlm-librispeech" + >>> lm_model_path = source + "/lm.ckpt" + >>> tokenizer_path = source + "/tokenizer.ckpt" + >>> tokenizer = SentencePieceProcessor() + >>> lm_model = TransformerLM( + ... vocab=5000, + ... d_model=768, + ... nhead=12, + ... num_encoder_layers=12, + ... num_decoder_layers=0, + ... d_ffn=3072, + ... dropout=0.0, + ... activation=torch.nn.GELU, + ... normalize_before=False, + ... ) + >>> pretrainer = Pretrainer( + ... collect_in=getfixture("tmp_path"), + ... loadables={ + ... "lm": lm_model, + ... "tokenizer": tokenizer, + ... }, + ... paths={ + ... "lm": lm_model_path, + ... "tokenizer": tokenizer_path, + ... }, + ... ) + >>> _ = pretrainer.collect_files() + >>> pretrainer.load_collected() + >>> from speechbrain.decoders.scorer import ( + ... TransformerLMRescorer, + ... RescorerBuilder, + ... ) + >>> transformerlm_rescorer = TransformerLMRescorer( + ... language_model=lm_model, + ... tokenizer=tokenizer, + ... temperature=1.0, + ... bos_index=1, + ... eos_index=2, + ... pad_index=0, + ... ) + >>> rescorer = RescorerBuilder( + ... rescorers=[transformerlm_rescorer], weights={"transformerlm": 1.0} + ... ) + >>> topk_hyps = [["HELLO", "HE LLO", "H E L L O"]] + >>> topk_scores = [[-2, -2, -2]] + >>> rescored_hyps, rescored_scores = rescorer.rescore( + ... topk_hyps, topk_scores + ... ) + >>> # NOTE: the returned hypotheses are already sorted by score. + >>> rescored_hyps # doctest: +SKIP + [["HELLO", "HE L L O", "HE LLO"]] + >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better. + >>> rescored_scores # doctest: +SKIP + [[-17.863974571228027, -25.12890625, -26.075977325439453]] + """ + + def __init__( + self, + language_model, + tokenizer, + device="cuda", + temperature=1.0, + bos_index=0, + eos_index=0, + pad_index=0, + ): + self.lm = language_model + self.lm.eval() + + self.tokenizer = tokenizer + self.temperature = temperature + self.softmax = sb.nnet.activations.Softmax(apply_log=True) + + self.device = device + self.bos_index = bos_index + self.eos_index = eos_index + self.pad_index = pad_index + + def normalize_text(self, text): + """This method should implement the normalization of the text before scoring. + + Default to uppercasing the text because the language models are trained on + LibriSpeech. + + Arguments + --------- + text : str + The text to be normalized. + + Returns + ------- + str + The normalized text. + """ + return text.upper() + + def to_device(self, device=None): + """This method moves the scorer to a device. + + If device is None, the scorer is moved to the default device provided + in the constructor. + + This method is dynamically called in the recipes when the stage is equal + to TEST. + + Arguments + --------- + device : str + The device to move the scorer to. + """ + if device is None: + self.lm.to(self.device) + else: + self.lm.to(device) + + def preprocess_func(self, topk_hyps): + """This method preprocesses the hypotheses before scoring. + + Arguments + --------- + topk_hyps : list of list of str + The hypotheses to be preprocessed. + + Returns + ------- + padded_hyps : torch.Tensor + The padded hypotheses. + enc_hyps_length : list of int + The length of each hypothesis. + """ + # 1. normalize + decoded_seq = [] + for batch in topk_hyps: + for seq in batch: + decoded_seq.append(self.normalize_text(seq)) + + # 2. encode text + enc_hyps = [] + for seq in decoded_seq: + enc_hyps.append( + torch.tensor( + [self.bos_index] + + self.tokenizer.encode_as_ids(seq) + + [self.eos_index] + ) + ) + + enc_hyps_length = [enc_seq.shape[0] for enc_seq in enc_hyps] + + # 3. pad sequences + padded_hyps = torch.nn.utils.rnn.pad_sequence( + enc_hyps, batch_first=True, padding_value=self.pad_index + ).to(self.lm.parameters().__next__().device) + + return padded_hyps, enc_hyps_length + + @torch.no_grad() + def rescore_hyps(self, topk_hyps): + """This method implement the rescoring of the hypotheses. + + Arguments + --------- + topk_hyps : list of list of str + The hypotheses to be rescored. + + Returns + ------- + log_probs_scores : torch.Tensor[B * Topk, 1] + The rescored hypotheses scores + """ + # preprocess hypotheses + padded_hyps, enc_hyps_length = self.preprocess_func(topk_hyps) + + bool_mask = [ + [1 if i < length else 0 for i in range(max(enc_hyps_length))] + for length in enc_hyps_length + ] + + bool_mask_tensor = torch.tensor( + bool_mask, dtype=torch.bool, device=padded_hyps.device + ) + + if not next(self.lm.parameters()).is_cuda: + self.lm.to(padded_hyps.device) + + # compute scores + logits = self.lm(padded_hyps) + log_probs = self.softmax(logits / self.temperature) + + log_probs[:, :, self.pad_index] = float("-inf") + + target_log_probs = ( + log_probs[:, :-1] + .gather(2, padded_hyps[:, 1:].unsqueeze(2)) + .squeeze(2) + ) + + target_log_probs = target_log_probs - log_probs[:, :-1].logsumexp( + dim=-1 + ) + log_probs_scores = torch.nansum( + target_log_probs * bool_mask_tensor[:, 1:], dim=-1 + ) + + return log_probs_scores + + +class HuggingFaceLMRescorer(BaseRescorerInterface): + """A wrapper of HuggingFace's TransformerLM based on the BaseRescorerInterface. + + Arguments + --------- + model_name : str + The name of the model to be loaded. + device : str + The device to be used for scoring. (default: "cuda") + + Example + ------- + >>> from speechbrain.decoders.scorer import ( + ... HuggingFaceLMRescorer, + ... RescorerBuilder, + ... ) + >>> source = "gpt2-medium" + >>> huggingfacelm_rescorer = HuggingFaceLMRescorer( + ... model_name=source, + ... ) + >>> rescorer = RescorerBuilder( + ... rescorers=[huggingfacelm_rescorer], weights={"huggingfacelm": 1.0} + ... ) + >>> topk_hyps = [ + ... ["Hello everyone.", "Hell o every one.", "Hello every one"] + ... ] + >>> topk_scores = [[-2, -2, -2]] + >>> rescored_hyps, rescored_scores = rescorer.rescore( + ... topk_hyps, topk_scores + ... ) + >>> # NOTE: the returned hypotheses are already sorted by score. + >>> rescored_hyps # doctest: +SKIP + [['Hello everyone.', 'Hello every one', 'Hell o every one.']] + >>> # NOTE: as we are returning log-probs, the more it is closer to 0, the better. + >>> rescored_scores # doctest: +SKIP + [[-20.03631591796875, -27.615638732910156, -42.662353515625]] + """ + + def __init__( + self, + model_name, + device="cuda", + ): + self.model_name = model_name + self.device = device + + try: + from transformers import AutoModelForCausalLM, AutoTokenizer + except ImportError: + raise ImportError( + "Please install transformers with: pip install transformers" + ) + + self.lm = AutoModelForCausalLM.from_pretrained(self.model_name).eval() + + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, use_fast=True + ) + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = "<|pad|>" + self.tokenizer.add_special_tokens( + {"additional_special_tokens": [self.tokenizer.pad_token]} + ) + self.lm.resize_token_embeddings( + len(self.tokenizer), pad_to_multiple_of=32 + ) + + self.bos_token = self.tokenizer.bos_token + self.eos_token = self.tokenizer.eos_token + + def to_device(self, device=None): + """This method moves the scorer to a device. + + If device is None, the scorer is moved to the default device provided + in the constructor. + + This method is dynamically called in the recipes when the stage is equal + to TEST. + + Arguments + --------- + device : str + The device to move the scorer to. + """ + if device is None: + self.lm.to(self.device) + else: + self.lm.to(device) + + def normalize_text(self, text): + """This method should implement the normalization of the text before scoring. + + Arguments + --------- + text : str + The text to be normalized. + + Returns + ------- + normalized_text : str + The normalized text. + In this case we do not apply any normalization. However, this method + can be overridden to apply any normalization. + """ + return text + + def _add_special_tokens(self, text): + """This method adds the special tokens to the text. + + Arguments + --------- + text : str + The text to be augmented. + + Returns + ------- + augmented_text : str + The augmented text. + """ + return self.bos_token + text + self.eos_token + + def preprocess_func(self, topk_hyps): + """This method preprocesses the hypotheses before scoring. + + Arguments + --------- + topk_hyps : list of str + The hypotheses to be preprocessed. + + Returns + ------- + encoding : tensor + The encoding of the hypotheses. + """ + # 1. normalize + normalized_hyps = [] + for batch in topk_hyps: + for seq in batch: + normalized_hyps.append(self.normalize_text(seq)) + + text_augmented_with_tokens = list( + map(self._add_special_tokens, normalized_hyps) + ) + encoding = self.tokenizer( + text_augmented_with_tokens, return_tensors="pt", padding=True + ) + return encoding + + @torch.no_grad() + def rescore_hyps(self, topk_hyps): + """This method implement the rescoring of the hypotheses. + + Arguments + --------- + topk_hyps : list of list of str + The hypotheses to be rescored. + + Returns + ------- + log_probs_scores : torch.Tensor[B * Topk, 1] + The rescored hypotheses scores + """ + encoding = self.preprocess_func(topk_hyps) + + ids = encoding["input_ids"].to(self.lm.device) + attention_mask = encoding["attention_mask"].to(self.lm.device) + logits = self.lm(ids, attention_mask=attention_mask)[0] + + logits[:, :, self.tokenizer.pad_token_id :] = float("-inf") + + target_log_probs = ( + logits[:, :-1].gather(2, ids[:, 1:].unsqueeze(2)).squeeze(2) + ) + + target_log_probs = target_log_probs - logits[:, :-1].logsumexp(dim=-1) + log_probs_scores = torch.nansum( + target_log_probs * attention_mask[:, 1:], dim=-1 + ) + + return log_probs_scores + + +class RescorerBuilder: + """Builds rescorer instance for beamsearch. + + The RescorerBuilder class is responsible for building a scorer instance for + beam search. It takes weights and rescorers classes. It combines the scorers based + on the weights specified and provides methods for rescoring text. + + This is the class to be used for building rescorer instances for beam search. + + Arguments + --------- + weights : dict + Weights of rescorers specified. + rescorers : list + Rescorers that re-ranks topk hypotheses. + """ + + def __init__( + self, + weights=dict(), + rescorers=list(), + ): + assert len(weights) == len(rescorers), ( + "Weights and rescorers are not matched." + ) + + self.weights = weights + + all_rescorer_names = [ + k.lower().split("rescorer")[0] + for k in globals().keys() + if k.endswith("Rescorer") + ] + full_rescorer_names = [ + impl.__class__.__name__.lower().split("rescorer")[0] + for impl in rescorers + ] + + # Have a default 0.0 weight for scorer not specified + init_weights = dict.fromkeys(all_rescorer_names, 0.0) + self.weights = {**init_weights, **weights} + self.rescorers = dict(zip(full_rescorer_names, rescorers)) + + self._validate_scorer(all_rescorer_names) + + def rescore(self, topk_candidates, topk_scores): + """This method rescores the topk candidates. + + Arguments + --------- + topk_candidates : list of list of str + The topk candidates to be rescored. + topk_scores : list of list of float + The scores of the topk candidates. + + Returns + ------- + output_candidates : list of list of str + The rescored candidates. + output_scores : list of list of float + The rescored scores. + """ + new_scores = topk_scores.copy() + + for k, impl in self.rescorers.items(): + scores = impl.rescore_hyps(topk_candidates) + + index_scores = 0 + for i in range(len(new_scores)): + for j in range(len(new_scores[i])): + new_scores[i][j] += ( + self.weights[k] * scores[index_scores].item() + ) + index_scores += 1 + + sorted_candidates = [ + list( + zip( + *sorted( + zip(sublist, score), key=lambda x: x[1], reverse=True + ) + ) + for sublist, score in zip(topk_candidates, new_scores) + ) + ] + + output_candidates = [] + output_scores = [] + for sublist in sorted_candidates: + for item in sublist: + texts, scores = item + output_candidates.append(list(texts)) + output_scores.append(list(scores)) + + return output_candidates, output_scores + + def _validate_scorer(self, rescorer_names): + """These error messages indicate rescorers are not properly set. + + Arguments + --------- + rescorer_names : list + Prefix of rescorers defined in speechbrain.decoders.scorer. + """ + if len(self.weights) > len(rescorer_names): + raise ValueError( + f"The keys of weights should be named in {rescorer_names}" + ) + + def move_rescorers_to_device(self, device=None): + """Moves rescorers to device. + + Useful to avoid having on GPU rescorers while being + on TRAIN and VALID Stages. + + Arguments + --------- + device : str + The device to be used for scoring. (default: None) + """ + for _, impl in self.rescorers.items(): + impl.to_device(device) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/seq2seq.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/seq2seq.py new file mode 100644 index 000000000..4aefc2d5f --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/seq2seq.py @@ -0,0 +1,2240 @@ +"""Decoding methods for seq2seq autoregressive model. + +Authors + * Adel Moumen 2022, 2023, 2024 + * Ju-Chieh Chou 2020 + * Peter Plantinga 2020 + * Mirco Ravanelli 2020 + * Sung-Lin Yeh 2020 +""" + +from functools import cached_property + +import torch +from torch.distributions import Categorical + +from speechbrain.decoders.utils import ( + _update_mem, + inflate_tensor, + mask_by_condition, +) +from speechbrain.utils.data_utils import undo_padding + + +class AlivedHypotheses(torch.nn.Module): + """This class handle the data for the hypotheses during the decoding. + + Arguments + --------- + alived_seq : torch.Tensor + The sequence of tokens for each hypothesis. + alived_log_probs : torch.Tensor + The log probabilities of each token for each hypothesis. + sequence_scores : torch.Tensor + The sum of log probabilities for each hypothesis. + """ + + def __init__(self, alived_seq, alived_log_probs, sequence_scores): + super().__init__() + self.alived_seq = alived_seq + self.alived_log_probs = alived_log_probs + self.sequence_scores = sequence_scores + + def __getitem__(self, index): + return ( + self.alived_seq[index], + self.alived_log_probs[index], + self.sequence_scores[index], + ) + + def __str__(self): + return f"AlivedHypotheses(alived_seq={self.alived_seq}, alived_log_probs={self.alived_log_probs}, sequence_scores={self.sequence_scores})" + + +class S2SBaseSearcher(torch.nn.Module): + """S2SBaseSearcher class to be inherited by other + decoding approaches for seq2seq model. + + Arguments + --------- + bos_index : int + The index of the beginning-of-sequence (bos) token. + eos_index : int + The index of end-of-sequence (eos) token. + min_decode_ratio : float + The ratio of minimum decoding steps to the length of encoder states. + max_decode_ratio : float + The ratio of maximum decoding steps to the length of encoder states. + """ + + def __init__( + self, bos_index, eos_index, min_decode_ratio, max_decode_ratio + ): + super().__init__() + self.bos_index = bos_index + self.eos_index = eos_index + self.min_decode_ratio = min_decode_ratio + self.max_decode_ratio = max_decode_ratio + + def forward(self, enc_states, wav_len): + """This method should implement the forward algorithm of decoding method. + + Arguments + --------- + enc_states : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + wav_len : torch.Tensor + The speechbrain-style relative length. + + Returns + ------- + hyps + The predicted tokens, as a list of lists or, if return_topk is True, + a Tensor of shape (batch, topk, max length of token_id sequences). + top_lengths + The length of each topk sequence in the batch. + top_scores + This final scores of topk hypotheses. + top_log_probs + The log probabilities of each hypotheses. + """ + raise NotImplementedError + return + + def forward_step( + self, inp_tokens, memory, enc_states, enc_lens, attention_mask=None + ): + """This method should implement one step of + forwarding operation in the autoregressive model. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current step. + memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + enc_states : torch.Tensor + The encoder states to be attended. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + + Returns + ------- + log_probs : torch.Tensor + Log-probabilities of the current step output. + memory : No limit + The memory variables generated in this step. + (ex. RNN hidden states). + attn : torch.Tensor + The attention weight for doing penalty. + """ + raise NotImplementedError + return + + def reset_mem(self, batch_size, device): + """This method should implement the resetting of + memory variables for the seq2seq model. + E.g., initializing zero vector as initial hidden states. + + Arguments + --------- + batch_size : int + The size of the batch. + device : torch.device + The device to put the initial variables. + + Return + ------ + memory : No limit + The initial memory variable. + """ + raise NotImplementedError + return + + def change_max_decoding_length(self, min_decode_steps, max_decode_steps): + """set the minimum/maximum length of enc_states to be attended.""" + return min_decode_steps, max_decode_steps + + def set_n_out(self): + """set the number of output tokens. + Overrides this function if the fc layer is embedded + in the model, e.g., Whisper. + """ + return self.fc.w.out_features + + def _check_end_condition(self, memory): + """This method is supposed to be overridden by the child class. + For instance, if the decoder has a maximal number of tokens that it can + attend to, this method should return True when the maximal number of tokens + is reached. + """ + return False + + +class S2SGreedySearcher(S2SBaseSearcher): + """This class implements the general forward-pass of + greedy decoding approach. See also S2SBaseSearcher(). + """ + + @torch.no_grad() + def forward(self, enc_states, wav_len, attention_mask=None): + """This method performs a greedy search. + + Arguments + --------- + enc_states : torch.Tensor + The precomputed encoder states to be used when decoding. + (ex. the encoded speech representation to be attended). + wav_len : torch.Tensor + The speechbrain-style relative length. + attention_mask : torch.Tensor + The attention mask to be used when decoding. + + Returns + ------- + hyps : List[List[int]] + List containing the hypotheses. + top_lengths : torch.Tensor (batch) + This tensor contains the length of each hypothesis. + top_scores : torch.Tensor (batch) + The score of each hypotheses. + top_log_probs : torch.Tensor (batch, max length of token_id sequences) + The log probabilities of each hypotheses. + """ + enc_lens = torch.round(enc_states.shape[1] * wav_len).int() + device = enc_states.device + batch_size = enc_states.shape[0] + + memory = self.reset_mem(batch_size, device=device) + + # Using bos as the first input + inp_tokens = ( + enc_states.new_zeros(batch_size).fill_(self.bos_index).long() + ) + + log_probs_lst = [] + min_decode_steps = int(enc_states.shape[1] * self.min_decode_ratio) + max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio) + + min_decode_steps, max_decode_steps = self.change_max_decoding_length( + min_decode_steps, max_decode_steps + ) + + has_ended = enc_states.new_zeros(batch_size).bool() + for step in range(min_decode_steps, max_decode_steps): + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask, + torch.ones( + batch_size, 1, device=device, dtype=torch.bool + ), + ], + dim=1, + ) + attention_mask[has_ended, -1] = False + + logits, memory, _ = self.forward_step( + inp_tokens, memory, enc_states, enc_lens, attention_mask + ) + + if self.temperature == 0: + inp_tokens = logits.argmax(dim=-1) + else: + inp_tokens = Categorical( + logits=logits / self.temperature + ).sample() + log_probs = torch.nn.functional.log_softmax(logits.float(), dim=-1) + log_probs_lst.append(log_probs) + + has_ended = has_ended | (inp_tokens == self.eos_index) + log_probs[has_ended] = -torch.inf + inp_tokens[has_ended] = self.eos_index + + if has_ended.all() or self._check_end_condition(memory): + break + + log_probs = torch.stack(log_probs_lst, dim=1) + + scores, predictions = log_probs.max(dim=-1) + mask = scores == -torch.inf + scores[mask] = 0 + predictions[mask] = self.eos_index + + ( + top_hyps, + top_lengths, + top_scores, + top_log_probs, + ) = self._get_top_prediction(predictions, scores, log_probs) + + # Convert best hypothesis to list + hyps = undo_padding(top_hyps[:, 0], top_lengths) + + return hyps, top_lengths, top_scores, top_log_probs + + def _get_top_prediction(self, hyps, scores, log_probs): + """This method sorts the scores and return corresponding hypothesis and log probs. + + Arguments + --------- + hyps : torch.Tensor (batch, max length of token_id sequences) + This tensor stores the predicted hypothesis. + scores : torch.Tensor (batch) + The score of each hypotheses. + log_probs : torch.Tensor (batch, max length of token_id sequences) + The log probabilities of each hypotheses. + + Returns + ------- + top_hyps : torch.Tensor (batch, max length of token_id sequences) + This tensor stores the best predicted hypothesis. + top_lengths : torch.Tensor (batch) + This tensor contains the length of each hypothesis. + top_scores : torch.Tensor (batch) + The score of each hypotheses. + top_log_probs : torch.Tensor (batch, max length of token_id sequences) + The log probabilities of each hypotheses. + """ + batch_size = hyps.size(0) + max_length = hyps.size(1) + top_lengths = [max_length] * batch_size + + # Collect lengths of top hyps + for pred_index in range(batch_size): + pred = hyps[pred_index] + pred_length = (pred == self.eos_index).nonzero(as_tuple=False) + if len(pred_length) > 0: + top_lengths[pred_index] = pred_length[0].item() + # Convert lists to tensors + top_lengths = torch.tensor( + top_lengths, dtype=torch.float, device=hyps.device + ) + + # Pick top log probabilities + top_log_probs = log_probs + + # Use SpeechBrain style lengths + top_lengths = top_lengths / max_length + + return ( + hyps.unsqueeze(1), + top_lengths.unsqueeze(1), + scores.unsqueeze(1), + top_log_probs.unsqueeze(1), + ) + + +class S2STransformerGreedySearcher(S2SGreedySearcher): + """This class implements the greedy decoding + for Transformer. + + Arguments + --------- + modules : list with the following one: + model : torch.nn.Module + A TransformerASR model. + seq_lin : torch.nn.Module + A linear output layer for the seq2seq model. + temperature : float + Temperature to use during decoding. + **kwargs + Arguments to pass to S2SGreedySearcher + """ + + def __init__(self, modules, temperature=0.0, **kwargs): + super().__init__(**kwargs) + + self.model = modules[0] + self.fc = modules[1] + self.softmax = torch.nn.LogSoftmax(dim=-1) + + self.temperature = temperature + + def reset_mem(self, batch_size, device): + """Needed to reset the memory during greedy search.""" + return None + + def forward_step( + self, inp_tokens, memory, enc_states, enc_lens, attention_mask=None + ): + """Performs a step in the implemented greedy searcher.""" + memory = _update_mem(inp_tokens, memory) + pred, attn = self.model.decode(memory, enc_states, enc_lens) + logits = self.fc(pred) + return logits[:, -1, :], memory, attn + + +class S2SHuggingFaceLLMGreedySearcher(S2SGreedySearcher): + """This class implements the greedy decoding + for HuggingFace LLM. + + Arguments + --------- + llm_model : torch.nn.Module + A HuggingFace LLM model. + temperature : float + Temperature to use during decoding. + **kwargs + Arguments to pass to S2SGreedySearcher + """ + + def __init__(self, llm_model, temperature=0.6, **kwargs): + super().__init__(**kwargs) + + self.llm_model = llm_model + self.temperature = temperature + self.txt_embedding = llm_model.model.get_input_embeddings() + + def reset_mem(self, batch_size, device): + """Needed to reset the memory during greedy search.""" + return None + + def _update_mem_embeddings(self, inp_tokens, memory): + """This method updates the memory during greedy search.""" + inp_embds = self.txt_embedding(inp_tokens.long()) + if memory is None: + return inp_embds + return torch.cat([memory, inp_embds], dim=1) + + def forward_step( + self, inp_tokens, memory, enc_states, enc_lens, attention_mask + ): + """Performs a step in the implemented greedy searcher.""" + memory = self._update_mem_embeddings(inp_tokens.unsqueeze(-1), memory) + multimodal_embds = torch.cat( + [ + enc_states, + memory, + ], + dim=1, + ) + logits = self.llm_model( + inputs_embeds=multimodal_embds, + attention_mask=attention_mask, + ).logits + return logits[:, -1, :], memory, None + + +class S2SWhisperGreedySearcher(S2SGreedySearcher): + """ + This class implements the greedy decoding + for Whisper neural nets made by OpenAI in + https://cdn.openai.com/papers/whisper.pdf. + + Arguments + --------- + model: HuggingFaceWhisper + The Whisper model. + temperature: float + The temperature to use during decoding. + use_kv_cache: bool (default: True) + Whether to use key-value cache. + suppress_blank: bool (default: True) + This will suppress blank outputs. + suppress_tokens: str or list (default: "-1") + list of tokens ids (or comma-separated token ids) to suppress + "-1" will suppress a set of symbols as defined in `model.non_speech_tokens()` + sample_len: int (default: None) + Maximum number of tokens to sample. + prefix: str or list (default: None) + Prefix to add to the input tokens. + See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + prompt: str or list (default: None) + Prompt to add to the input tokens. + See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + **kwargs + see S2SBaseSearcher, arguments are directly passed. + """ + + def __init__( + self, + model, + temperature=0.0, + use_kv_cache=True, + suppress_blank=True, + suppress_tokens="-1", + sample_len=None, + prefix=None, + prompt=None, + **kwargs, + ): + super().__init__( + bos_index=model.bos, + eos_index=model.eos, + **kwargs, + ) + self.model = model + self.temperature = temperature + + self.use_kv_cache = use_kv_cache + self.kv_cache = None + self.suppress_blank = suppress_blank + self.suppress_tokens = suppress_tokens + + self.prefix = prefix + self.prompt = prompt + + self.max_attn_tokens = self.model.model.decoder.config.max_length + self.sample_len = sample_len or self.max_attn_tokens // 2 + + self.initial_tokens = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.eos_index: int = self.model.eos + self.bos_index: int = self.initial_tokens[-1] + + self.no_speech_probs = None + self.lang_tokens = None + + def set_lang_tokens(self, lang_tokens): + """Set the language to be used during decoding.""" + self.lang_tokens = lang_tokens + + def set_task(self, task): + """Set the task to be used during decoding.""" + self.model.set_task(task) + self.initial_tokens = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.bos_index: int = self.initial_tokens[-1] + + def set_prompt(self, prompt): + """Set the prompt to be used during decoding.""" + self.prompt = prompt + self.initial_tokens = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.bos_index: int = self.initial_tokens[-1] + + @cached_property + def get_tokens_to_suppress(self): + """Get the tokens to suppress during decoding if self.config.suppress_tokens is None.""" + suppress_tokens = self.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.model.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), ( + "suppress_tokens must be a list" + ) + + suppress_tokens.extend( + [ + self.model.transcribe, + self.model.translate, + self.model.bos, + self.model.bos_prev, + self.model.bos_lm, + ] + ) + + return tuple(sorted(set(suppress_tokens))) + + def _get_initial_tokens(self): + """Get the initial tokens to be used during decoding.""" + tokens = self.model.tokenizer.prefix_tokens + prefix = self.prefix + prompt = self.prompt + if prefix: + prefix_tokens = ( + self.model.tokenizer.encode( + " " + prefix.strip(), add_special_tokens=False + ) + if isinstance(prefix, str) + else prefix + ) + if self.sample_len is not None: + max_prefix_len = self.max_attn_tokens // 2 - self.sample_len + prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = tokens + prefix_tokens + + if prompt: + prompt_tokens = ( + self.model.tokenizer.encode( + " " + prompt.strip(), add_special_tokens=False + ) + if isinstance(prompt, str) + else prompt + ) + tokens = ( + [self.model.bos_prev] + + prompt_tokens[-(self.max_attn_tokens // 2 - 1) :] + + tokens + ) + return tuple(tokens) + + def reset_mem(self, batch_size, device): + """This method set the first tokens to be decoder_input_tokens during search.""" + # reset KV cache + if self.use_kv_cache: + self.kv_cache = None + + self.no_speech_probs = [torch.nan] * batch_size + # the last token will be used as the first input token + # explaining why we are skipping it. + memory_tokens = self.initial_tokens[:-1] + mem = torch.tensor([memory_tokens] * batch_size).to(device) + if self.lang_tokens is not None: + mem[:, self.initial_tokens.index(self.model.bos) + 1] = ( + self.lang_tokens + ) + # after using it, reset it. + self.lang_token = None + return mem + + def forward_step( + self, inp_tokens, memory, enc_states, enc_lens, attention_mask=None + ): + """Performs a step in the implemented beamsearcher.""" + tokens = _update_mem(inp_tokens, memory) + + logits, attn, kv = self.model.forward_decoder( + enc_states, tokens, past_key_values=self.kv_cache + ) + + if tokens.shape[1] == self.sample_begin: + probs_at_bos = ( + logits[:, self.initial_tokens.index(self.model.bos)] + .float() + .softmax(dim=-1) + ) + self.no_speech_probs = probs_at_bos[ + :, self.model.no_speech + ].tolist() + + logits = logits[:, -1] + + if self.use_kv_cache: + self.kv_cache = kv + + if self.suppress_blank: + if tokens.shape[1] == self.sample_begin: + logits[ + :, + self.model.tokenizer.encode(" ", add_special_tokens=False) + + [self.eos_index], + ] = -torch.inf + + if self.suppress_tokens: + if self.model.config.suppress_tokens is None: + tokens_to_suppress = self.get_tokens_to_suppress + else: + tokens_to_suppress = self.model.get_suppress_tokens + logits[:, list(tokens_to_suppress)] = -torch.inf + + return logits, tokens, attn + + def _check_end_condition(self, memory): + """This method checks if the max length is reached.""" + return memory.shape[1] >= self.max_attn_tokens - self.sample_begin + + +class S2SRNNGreedySearcher(S2SGreedySearcher): + """ + This class implements the greedy decoding + for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). + See also S2SBaseSearcher() and S2SGreedySearcher(). + + Arguments + --------- + embedding : torch.nn.Module + An embedding layer. + decoder : torch.nn.Module + Attentional RNN decoder. + linear : torch.nn.Module + A linear output layer. + temperature : float + The temperature to use during decoding. + **kwargs + see S2SBaseSearcher, arguments are directly passed. + + Example + ------- + >>> import speechbrain as sb + >>> from speechbrain.decoders import S2SRNNGreedySearcher + >>> emb = torch.nn.Embedding(5, 3) + >>> dec = sb.nnet.RNN.AttentionalRNNDecoder( + ... "gru", "content", 3, 3, 1, enc_dim=7, input_size=3 + ... ) + >>> lin = sb.nnet.linear.Linear(n_neurons=5, input_size=3) + >>> searcher = S2SRNNGreedySearcher( + ... embedding=emb, + ... decoder=dec, + ... linear=lin, + ... bos_index=0, + ... eos_index=1, + ... min_decode_ratio=0, + ... max_decode_ratio=1, + ... ) + >>> batch_size = 2 + >>> enc = torch.rand([batch_size, 6, 7]) + >>> wav_len = torch.ones([batch_size]) + >>> top_hyps, top_lengths, _, _ = searcher(enc, wav_len) + """ + + def __init__(self, embedding, decoder, linear, temperature=0.0, **kwargs): + super().__init__(**kwargs) + self.emb = embedding + self.dec = decoder + self.fc = linear + self.temperature = temperature + self.softmax = torch.nn.LogSoftmax(dim=-1) + + def reset_mem(self, batch_size, device): + """When doing greedy search, keep hidden state (hs) and context vector (c) + as memory. + """ + hs = None + self.dec.attn.reset() + c = torch.zeros(batch_size, self.dec.attn_dim, device=device) + return hs, c + + def forward_step( + self, inp_tokens, memory, enc_states, enc_lens, attention_mask=None + ): + """Performs a step in the implemented beamsearcher.""" + hs, c = memory + e = self.emb(inp_tokens) + dec_out, hs, c, w = self.dec.forward_step( + e, hs, c, enc_states, enc_lens + ) + logits = self.fc(dec_out) + return logits, (hs, c), w + + +class S2SBeamSearcher(S2SBaseSearcher): + """This class implements the beam-search algorithm for the seq2seq model. + See also S2SBaseSearcher(). + + Arguments + --------- + bos_index : int + The index of beginning-of-sequence token. + eos_index : int + The index of end-of-sequence token. + min_decode_ratio : float + The ratio of minimum decoding steps to length of encoder states. + max_decode_ratio : float + The ratio of maximum decoding steps to length of encoder states. + beam_size : int + The width of beam. + scorer: speechbrain.decoders.scorers.ScorerBuilder + Scorer instance. Default: None. + return_topk : bool + Whether to return topk hypotheses. The topk hypotheses will be + padded to the same length. Default: False. + topk : int + If return_topk is True, then return topk hypotheses. Default: 1. + using_eos_threshold : bool + Whether to use eos threshold. Default: True. + eos_threshold : float + The threshold coefficient for eos token. Default: 1.5. + See 3.1.2 in reference: https://arxiv.org/abs/1904.02619 + length_normalization : bool + Whether to divide the scores by the length. Default: True. + using_max_attn_shift: bool + Whether using the max_attn_shift constraint. Default: False. + max_attn_shift: int + Beam search will block the beams that attention shift more + than max_attn_shift. Default: 60. + Reference: https://arxiv.org/abs/1904.02619 + minus_inf : float + The value of minus infinity to block some path + of the search. Default: -1e20. + """ + + def __init__( + self, + bos_index, + eos_index, + min_decode_ratio, + max_decode_ratio, + beam_size, + scorer=None, + return_topk=False, + topk=1, + using_eos_threshold=True, + eos_threshold=1.5, + length_normalization=True, + using_max_attn_shift=False, + max_attn_shift=60, + minus_inf=-1e20, + ): + super().__init__( + bos_index, eos_index, min_decode_ratio, max_decode_ratio + ) + self.beam_size = beam_size + self.scorer = scorer + self.return_topk = return_topk + self.topk = topk + self.length_normalization = length_normalization + self.using_eos_threshold = using_eos_threshold + self.eos_threshold = eos_threshold + self.using_max_attn_shift = using_max_attn_shift + self.max_attn_shift = max_attn_shift + self.attn_weight = 1.0 + self.ctc_weight = 0.0 + self.minus_inf = minus_inf + + if self.scorer is not None: + # Check length normalization + if length_normalization and self.scorer.weights["length"] > 0.0: + raise ValueError( + "Length normalization is not compatible with length rewarding." + ) + if self.scorer.weights["ctc"] > 0.0: + # Check indices for ctc + all_scorers = { + **self.scorer.full_scorers, + **self.scorer.partial_scorers, + } + blank_index = all_scorers["ctc"].blank_index + if len({bos_index, eos_index, blank_index}) < 3: + raise ValueError( + "Set blank, eos and bos to different indexes for joint ATT/CTC or CTC decoding" + ) + + self.ctc_weight = self.scorer.weights["ctc"] + self.attn_weight = 1.0 - self.ctc_weight + + def _check_full_beams(self, hyps): + """This method checks whether hyps has been full. + + Arguments + --------- + hyps : List + This list contains batch_size number. + Each inside list contains a list stores all the hypothesis for this sentence. + + Returns + ------- + bool + Whether the hyps has been full. + """ + hyps_len = [len(lst) for lst in hyps] + beams_size = [self.beam_size for _ in range(len(hyps_len))] + return hyps_len == beams_size + + def _check_attn_shift(self, attn, prev_attn_peak): + """This method checks whether attention shift is more than attn_shift. + + Arguments + --------- + attn : torch.Tensor + The attention to be checked. + prev_attn_peak : torch.Tensor + The previous attention peak place. + + Returns + ------- + cond : torch.BoolTensor + Each element represents whether the beam is within the max_shift range. + attn_peak : torch.Tensor + The peak of the attn tensor. + """ + # Block the candidates that exceed the max shift + _, attn_peak = torch.max(attn, dim=1) + lt_cond = attn_peak <= (prev_attn_peak + self.max_attn_shift) + mt_cond = attn_peak > (prev_attn_peak - self.max_attn_shift) + + # True if not exceed limit + # Multiplication equals to element-wise and for tensor + cond = (lt_cond * mt_cond).unsqueeze(1) + return cond, attn_peak + + def _check_eos_threshold(self, log_probs): + """This method checks whether eos log-probabilities exceed threshold. + + Arguments + --------- + log_probs : torch.Tensor + The log-probabilities. + + Returns + ------- + cond : torch.BoolTensor + Each element represents whether the eos log-probabilities will be kept. + """ + max_probs, _ = torch.max(log_probs, dim=-1) + eos_probs = log_probs[:, self.eos_index] + cond = eos_probs > (self.eos_threshold * max_probs) + return cond + + def init_hypotheses(self): + """This method initializes the AlivedHypotheses object. + + Returns + ------- + AlivedHypotheses + The alived hypotheses filled with the initial values. + """ + return AlivedHypotheses( + alived_seq=torch.empty(self.n_bh, 0, device=self.device).long(), + alived_log_probs=torch.empty(self.n_bh, 0, device=self.device), + sequence_scores=torch.empty(self.n_bh, device=self.device) + .fill_(float("-inf")) + .index_fill_(0, self.beam_offset, 0.0), + ) + + def _attn_weight_step( + self, inp_tokens, memory, enc_states, enc_lens, attn, log_probs + ): + """This method computes a forward_step if attn_weight is superior to 0. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current step. + memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + enc_states : torch.Tensor + The encoder states to be attended. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + attn : torch.Tensor + The attention weight. + log_probs : torch.Tensor + The log-probabilities of the current step output. + + Returns + ------- + log_probs : torch.Tensor + Log-probabilities of the current step output. + memory : No limit + The memory variables generated in this step. + (ex. RNN hidden states). + attn : torch.Tensor + The attention weight. + """ + if self.attn_weight > 0: + log_probs, memory, attn = self.forward_step( + inp_tokens, memory, enc_states, enc_lens + ) + log_probs = self.attn_weight * log_probs + return log_probs, memory, attn + + def _max_attn_shift_step(self, attn, prev_attn_peak, log_probs): + """This method will block the beams that attention shift more + than max_attn_shift. + + Arguments + --------- + attn : torch.Tensor + The attention weight. + prev_attn_peak : torch.Tensor + The previous attention peak place. + log_probs : torch.Tensor + The log-probabilities of the current step output. + + Returns + ------- + log_probs : torch.Tensor + Log-probabilities of the current step output. + prev_attn_peak : torch.Tensor + The previous attention peak place. + """ + if self.using_max_attn_shift: + cond, prev_attn_peak = self._check_attn_shift(attn, prev_attn_peak) + log_probs = mask_by_condition( + log_probs, cond, fill_value=self.minus_inf + ) + return log_probs, prev_attn_peak + + def _scorer_step(self, inp_tokens, scorer_memory, attn, log_probs): + """This method call the scorers if scorer is not None. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current step. + scorer_memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + attn : torch.Tensor + The attention weight. + log_probs : torch.Tensor + The log-probabilities of the current step output. + + Returns + ------- + log_probs : torch.Tensor + Log-probabilities of the current step output. + scorer_memory : No limit + The memory variables generated in this step. + """ + if self.scorer is not None: + log_probs, scorer_memory = self.scorer.score( + inp_tokens, scorer_memory, attn, log_probs, self.beam_size + ) + return log_probs, scorer_memory + + def _set_eos_minus_inf_step(self, log_probs, step, min_decode_steps): + """This method set the log_probs of eos to minus infinity if the step is less than min_decode_steps. + + Arguments + --------- + log_probs : torch.Tensor + The log-probabilities of the current step output. + step : int + The current decoding step. + min_decode_steps : int + The minimum decoding steps. + + Returns + ------- + log_probs : torch.Tensor + Log-probabilities of the current step output. + """ + if step < min_decode_steps: + log_probs[:, self.eos_index] = self.minus_inf + return log_probs + + def _eos_threshold_step(self, log_probs): + """This method set the log_probs of eos to minus infinity if the eos log-probabilities is less than eos_threshold. + + Arguments + --------- + log_probs : torch.Tensor + The log-probabilities of the current step output. + + Returns + ------- + log_probs : torch.Tensor + Log-probabilities of the current step output. + """ + if self.using_eos_threshold: + cond = self._check_eos_threshold(log_probs) + log_probs[:, self.eos_index] = mask_by_condition( + log_probs[:, self.eos_index], cond, fill_value=self.minus_inf + ) + return log_probs + + def _attn_weight_permute_memory_step(self, memory, predecessors): + """This method permute the memory if attn_weight is superior to 0. + + Arguments + --------- + memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + predecessors : torch.Tensor + The index of which beam the current top-K output came from in (t-1) steps. + + Returns + ------- + memory : No limit + The memory variables generated in this step. + (ex. RNN hidden states). + """ + if self.attn_weight > 0: + memory = self.permute_mem(memory, index=predecessors) + return memory + + def _scorer_permute_memory_step( + self, scorer_memory, predecessors, candidates + ): + """This method permute the scorer_memory if scorer is not None. + + Arguments + --------- + scorer_memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + predecessors : torch.Tensor + The index of which beam the current top-K output came from in (t-1) steps. + candidates : torch.Tensor + The index of the current top-K output. + + Returns + ------- + scorer_memory : No limit + The memory variables generated in this step. + """ + if self.scorer is not None: + scorer_memory = self.scorer.permute_scorer_mem( + scorer_memory, index=predecessors, candidates=candidates + ) + return scorer_memory + + def _max_attn_shift_permute_memory_step(self, prev_attn_peak, predecessors): + """This method permute the prev_attn_peak if using_max_attn_shift is True. + + Arguments + --------- + prev_attn_peak : torch.Tensor + The previous attention peak place. + predecessors : torch.Tensor + The index of which beam the current top-K output came from in (t-1) steps. + + Returns + ------- + prev_attn_peak : torch.Tensor + The previous attention peak place. + """ + if self.using_max_attn_shift: + prev_attn_peak = torch.index_select( + prev_attn_peak, dim=0, index=predecessors + ) + return prev_attn_peak + + def _update_reset_memory(self, enc_states, enc_lens): + """Call reset memory for each module. + + Arguments + --------- + enc_states : torch.Tensor + The encoder states to be attended. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + + Returns + ------- + memory : No limit + The memory variables generated in this step. + scorer_memory : No limit + The memory variables generated in this step. + """ + memory = self.reset_mem(self.n_bh, device=self.device) + scorer_memory = None + if self.scorer is not None: + scorer_memory = self.scorer.reset_scorer_mem(enc_states, enc_lens) + return memory, scorer_memory + + def _update_permute_memory( + self, memory, scorer_memory, predecessors, candidates, prev_attn_peak + ): + """Call permute memory for each module. It allows us to synchronize the memory with the output. + + Arguments + --------- + memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + scorer_memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + predecessors : torch.Tensor + The index of which beam the current top-K output came from in (t-1) steps. + candidates : torch.Tensor + The index of the current top-K output. + prev_attn_peak : torch.Tensor + The previous attention peak place. + + Returns + ------- + memory : No limit + The memory variables generated in this step. + scorer_memory : No limit + The memory variables generated in this step. + prev_attn_peak : torch.Tensor + The previous attention peak place. + """ + memory = self._attn_weight_permute_memory_step(memory, predecessors) + + scorer_memory = self._scorer_permute_memory_step( + scorer_memory, predecessors, candidates + ) + + # If using_max_attn_shift, then the previous attn peak has to be permuted too. + prev_attn_peak = self._max_attn_shift_permute_memory_step( + prev_attn_peak, predecessors + ) + + return memory, scorer_memory, prev_attn_peak + + def _update_sequences_and_log_probs( + self, log_probs, inp_tokens, predecessors, candidates, alived_hyps + ): + """This method update sequences and log probabilities by adding the new inp_tokens. + + Arguments + --------- + log_probs : torch.Tensor + The log-probabilities of the current step output. + inp_tokens : torch.Tensor + The input tensor of the current step. + predecessors : torch.Tensor + The index of which beam the current top-K output came from in (t-1) steps. + candidates : torch.Tensor + The index of the current top-K output. + alived_hyps : AlivedHypotheses + The alived hypotheses. + + Returns + ------- + alived_hyps : AlivedHypotheses + The alived hypotheses. + """ + # Update alived_seq + alived_hyps.alived_seq = torch.cat( + [ + torch.index_select( + alived_hyps.alived_seq, dim=0, index=predecessors + ), + inp_tokens.unsqueeze(1), + ], + dim=-1, + ) + + # Takes the log-probabilities + beam_log_probs = log_probs[ + torch.arange(self.batch_size).unsqueeze(1), candidates + ].reshape(self.n_bh) + + # Update alived_log_probs + alived_hyps.alived_log_probs = torch.cat( + [ + torch.index_select( + alived_hyps.alived_log_probs, dim=0, index=predecessors + ), + beam_log_probs.unsqueeze(1), + ], + dim=-1, + ) + + return alived_hyps + + def _compute_scores_and_next_inp_tokens(self, alived_hyps, log_probs, step): + """Compute scores and next input tokens. + + Arguments + --------- + alived_hyps : AlivedHypotheses + The alived hypotheses. + log_probs : torch.Tensor + The log-probabilities of the current step output. + step : int + The current decoding step. + + Returns + ------- + scores : torch.Tensor + The scores of the current step output. + candidates : torch.Tensor + The index of the current top-K output. + predecessors : torch.Tensor + The index of which beam the current top-K output came from in (t-1) steps. + inp_tokens : torch.Tensor + The input tensor of the current step. + alived_hyps : AlivedHypotheses + The alived hypotheses. + """ + scores = alived_hyps.sequence_scores.unsqueeze(1).expand(-1, self.n_out) + scores = scores + log_probs + + # length normalization + if self.length_normalization: + scores = scores / (step + 1) + + # keep topk beams + scores, candidates = scores.view(self.batch_size, -1).topk( + self.beam_size, dim=-1 + ) + + # The input for the next step, also the output of current step. + inp_tokens = (candidates % self.n_out).view(self.n_bh) + + scores = scores.view(self.n_bh) + alived_hyps.sequence_scores = scores + + # recover the length normalization + if self.length_normalization: + alived_hyps.sequence_scores = alived_hyps.sequence_scores * ( + step + 1 + ) + + # The index of which beam the current top-K output came from in (t-1) steps. + predecessors = ( + torch.div(candidates, self.n_out, rounding_mode="floor") + + self.beam_offset.unsqueeze(1).expand_as(candidates) + ).view(self.n_bh) + + return ( + scores, + candidates, + predecessors, + inp_tokens, + alived_hyps, + ) + + def init_beam_search_data(self, enc_states, wav_len): + """Initialize the beam search data. + + Arguments + --------- + enc_states : torch.Tensor + The encoder states to be attended. + wav_len : torch.Tensor + The actual length of each enc_states sequence. + + Returns + ------- + alived_hyps : AlivedHypotheses + The alived hypotheses. + inp_tokens : torch.Tensor + The input tensor of the current step. + log_probs : torch.Tensor + The log-probabilities of the current step output. + eos_hyps_and_log_probs_scores : list + Generated hypotheses (the ones that have reached eos) and log probs scores. + memory : No limit + The memory variables generated in this step. + scorer_memory : No limit + The memory variables generated in this step. + attn : torch.Tensor + The attention weight. + prev_attn_peak : torch.Tensor + The previous attention peak place. + enc_states : torch.Tensor + The encoder states to be attended. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + """ + enc_lens = torch.round(enc_states.shape[1] * wav_len).int() + + self.device = enc_states.device + self.batch_size = enc_states.shape[0] + self.n_bh = self.batch_size * self.beam_size + + self.n_out = self.set_n_out() + + memory, scorer_memory = self._update_reset_memory(enc_states, enc_lens) + + # Inflate the enc_states and enc_len by beam_size times + enc_states = inflate_tensor(enc_states, times=self.beam_size, dim=0) + enc_lens = inflate_tensor(enc_lens, times=self.beam_size, dim=0) + + # Using bos as the first input + inp_tokens = ( + torch.zeros(self.n_bh, device=self.device) + .fill_(self.bos_index) + .long() + ) + + # The first index of each sentence. + self.beam_offset = ( + torch.arange(self.batch_size, device=self.device) * self.beam_size + ) + + # initialize sequence scores variables. + sequence_scores = torch.empty(self.n_bh, device=self.device).fill_( + self.minus_inf + ) + + # keep only the first to make sure no redundancy. + sequence_scores.index_fill_(0, self.beam_offset, 0.0) + + # keep the hypothesis that reaches eos and their corresponding score and log_probs. + eos_hyps_and_log_probs_scores = [[] for _ in range(self.batch_size)] + + self.min_decode_steps = int(enc_states.shape[1] * self.min_decode_ratio) + self.max_decode_steps = int(enc_states.shape[1] * self.max_decode_ratio) + + # the decoding steps can be based on the max number of tokens that a decoder can process + # (e.g., 448 for Whisper). + ( + self.min_decode_steps, + self.max_decode_steps, + ) = self.change_max_decoding_length( + self.min_decode_steps, self.max_decode_steps + ) + + # Initialize the previous attention peak to zero + # This variable will be used when using_max_attn_shift=True + prev_attn_peak = torch.zeros(self.n_bh, device=self.device) + attn = None + + log_probs = torch.full((self.n_bh, self.n_out), 0.0, device=self.device) + + alived_hyps = self.init_hypotheses() + + return ( + alived_hyps, + inp_tokens, + log_probs, + eos_hyps_and_log_probs_scores, + memory, + scorer_memory, + attn, + prev_attn_peak, + enc_states, + enc_lens, + ) + + def _update_hyps_and_scores_if_eos_token( + self, inp_tokens, alived_hyps, eos_hyps_and_log_probs_scores, scores + ): + """This method will update hyps and scores if inp_tokens are eos. + + Arguments + --------- + inp_tokens : torch.Tensor + The current output. + alived_hyps : AlivedHypotheses + alived_seq : torch.Tensor + alived_log_probs : torch.Tensor + eos_hyps_and_log_probs_scores : list + Generated hypotheses (the ones that have reached eos) and log probs scores. + scores : torch.Tensor + Scores at the current step. + + Returns + ------- + is_eos : torch.BoolTensor + Each element represents whether the token is eos. + """ + is_eos = inp_tokens.eq(self.eos_index) + (eos_indices,) = torch.nonzero(is_eos, as_tuple=True) + + # Store the hypothesis and their scores when reaching eos. + if eos_indices.shape[0] > 0: + for index in eos_indices: + # convert to int + index = index.item() + batch_id = torch.div( + index, self.beam_size, rounding_mode="floor" + ) + if ( + len(eos_hyps_and_log_probs_scores[batch_id]) + == self.beam_size + ): + continue + hyp = alived_hyps.alived_seq[index, :] + log_probs = alived_hyps.alived_log_probs[index, :] + final_scores = scores[index].clone() + eos_hyps_and_log_probs_scores[batch_id].append( + (hyp, log_probs, final_scores) + ) + + return is_eos + + def _get_topk_prediction(self, eos_hyps_and_log_probs_scores): + """This method sorts the scores and return corresponding hypothesis and log probs. + + Arguments + --------- + eos_hyps_and_log_probs_scores : list + Generated hypotheses (the ones that have reached eos) and log probs scores. + + Returns + ------- + topk_hyps : torch.Tensor (batch, topk, max length of token_id sequences) + This tensor stores the topk predicted hypothesis. + topk_lengths : torch.Tensor (batch, topk) + This tensor contains the final scores of topk hypotheses. + topk_scores : torch.Tensor (batch, topk) + The length of each topk sequence in the batch. + topk_log_probs : torch.Tensor (batch, topk, max length of token_id sequences) + The log probabilities of each hypotheses. + """ + top_hyps, top_log_probs, top_scores, top_lengths = [], [], [], [] + batch_size = len(eos_hyps_and_log_probs_scores) + + # Collect hypotheses + for i in range(len(eos_hyps_and_log_probs_scores)): + hyps, log_probs, scores = zip(*eos_hyps_and_log_probs_scores[i]) + top_hyps += hyps + top_scores += scores + top_log_probs += log_probs + top_lengths += [len(hyp) for hyp in hyps] + + # Convert lists to tensors + top_hyps = torch.nn.utils.rnn.pad_sequence( + top_hyps, batch_first=True, padding_value=0 + ) + top_log_probs = torch.nn.utils.rnn.pad_sequence( + top_log_probs, batch_first=True, padding_value=0 + ) + top_lengths = torch.tensor( + top_lengths, dtype=torch.float, device=top_hyps.device + ) + top_scores = torch.stack((top_scores), dim=0).view(batch_size, -1) + + # Use SpeechBrain style lengths + top_lengths = (top_lengths - 1) / top_hyps.size(1) + + # Get topk indices + topk_scores, indices = top_scores.topk(self.topk, dim=-1) + indices = (indices + self.beam_offset.unsqueeze(1)).view( + batch_size * self.topk + ) + # Select topk hypotheses + topk_hyps = torch.index_select(top_hyps, dim=0, index=indices) + topk_hyps = topk_hyps.view(batch_size, self.topk, -1) + topk_lengths = torch.index_select(top_lengths, dim=0, index=indices) + topk_lengths = topk_lengths.view(batch_size, self.topk) + topk_log_probs = torch.index_select(top_log_probs, dim=0, index=indices) + topk_log_probs = topk_log_probs.view(batch_size, self.topk, -1) + + return topk_hyps, topk_lengths, topk_scores, topk_log_probs + + def search_step( + self, + alived_hyps, + inp_tokens, + log_probs, + eos_hyps_and_log_probs_scores, + memory, + scorer_memory, + attn, + prev_attn_peak, + enc_states, + enc_lens, + step, + ): + """A search step for the next most likely tokens. + + Arguments + --------- + alived_hyps : AlivedHypotheses + The alived hypotheses. + inp_tokens : torch.Tensor + The input tensor of the current step. + log_probs : torch.Tensor + The log-probabilities of the current step output. + eos_hyps_and_log_probs_scores : list + Generated hypotheses (the ones that have reached eos) and log probs scores. + memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + scorer_memory : No limit + The memory variables input for this step. + (ex. RNN hidden states). + attn : torch.Tensor + The attention weight. + prev_attn_peak : torch.Tensor + The previous attention peak place. + enc_states : torch.Tensor + The encoder states to be attended. + enc_lens : torch.Tensor + The actual length of each enc_states sequence. + step : int + The current decoding step. + + Returns + ------- + alived_hyps : AlivedHypotheses + The alived hypotheses. + inp_tokens : torch.Tensor + The input tensor of the current step. + log_probs : torch.Tensor + The log-probabilities of the current step output. + eos_hyps_and_log_probs_scores : list + Generated hypotheses (the ones that have reached eos) and log probs scores. + memory : No limit + The memory variables generated in this step. + scorer_memory : No limit + The memory variables generated in this step. + attn : torch.Tensor + The attention weight. + prev_attn_peak : torch.Tensor + The previous attention peak place. + scores : torch.Tensor + The scores of the current step output. + """ + (log_probs, memory, attn) = self._attn_weight_step( + inp_tokens, memory, enc_states, enc_lens, attn, log_probs + ) + + # Keep the original value + log_probs_clone = log_probs.clone().reshape(self.batch_size, -1) + + (log_probs, prev_attn_peak) = self._max_attn_shift_step( + attn, prev_attn_peak, log_probs + ) + + log_probs = self._set_eos_minus_inf_step( + log_probs, step, self.min_decode_steps + ) + + log_probs = self._eos_threshold_step(log_probs) + + (log_probs, scorer_memory) = self._scorer_step( + inp_tokens, scorer_memory, attn, log_probs + ) + + ( + scores, + candidates, + predecessors, + inp_tokens, + alived_hyps, + ) = self._compute_scores_and_next_inp_tokens( + alived_hyps, log_probs, step + ) + + memory, scorer_memory, prev_attn_peak = self._update_permute_memory( + memory, scorer_memory, predecessors, candidates, prev_attn_peak + ) + + alived_hyps = self._update_sequences_and_log_probs( + log_probs_clone, inp_tokens, predecessors, candidates, alived_hyps + ) + + is_eos = self._update_hyps_and_scores_if_eos_token( + inp_tokens, alived_hyps, eos_hyps_and_log_probs_scores, scores + ) + + # Block the paths that have reached eos. + alived_hyps.sequence_scores.masked_fill_(is_eos, float("-inf")) + + return ( + alived_hyps, + inp_tokens, + log_probs, + eos_hyps_and_log_probs_scores, + memory, + scorer_memory, + attn, + prev_attn_peak, + scores, + ) + + def _fill_alived_hyps_with_eos_token( + self, alived_hyps, eos_hyps_and_log_probs_scores, scores + ): + """Fill the alived_hyps that have not reached eos with eos. + + Arguments + --------- + alived_hyps : AlivedHypotheses + The alived hypotheses. + eos_hyps_and_log_probs_scores : list + Generated hypotheses (the ones that have reached eos) and log probs scores. + scores : torch.Tensor + The scores of the current step output. + + Returns + ------- + eos_hyps_and_log_probs_scores : list + Generated hypotheses (the ones that have reached eos) and log probs scores. + """ + if not self._check_full_beams(eos_hyps_and_log_probs_scores): + # Using all eos to fill-up the hyps. + inp_tokens = ( + torch.zeros(self.n_bh, device=self.device) + .fill_(self.eos_index) + .long() + ) + self._update_hyps_and_scores_if_eos_token( + inp_tokens, alived_hyps, eos_hyps_and_log_probs_scores, scores + ) + + return eos_hyps_and_log_probs_scores + + def forward(self, enc_states, wav_len): # noqa: C901 + """Applies beamsearch and returns the predicted tokens. + + Arguments + --------- + enc_states : torch.Tensor + The encoder states to be attended. + wav_len : torch.Tensor + The actual length of each enc_states sequence. + + Returns + ------- + hyps : list + The predicted tokens. + best_lens : torch.Tensor + The length of each predicted tokens. + best_scores : torch.Tensor + The scores of each predicted tokens. + best_log_probs : torch.Tensor + The log probabilities of each predicted tokens. + """ + ( + alived_hyps, + inp_tokens, + log_probs, + eos_hyps_and_log_probs_scores, + memory, + scorer_memory, + attn, + prev_attn_peak, + enc_states, + enc_lens, + ) = self.init_beam_search_data(enc_states, wav_len) + + for step in range(self.max_decode_steps): + # terminate condition + if self._check_full_beams(eos_hyps_and_log_probs_scores): + break + + ( + alived_hyps, + inp_tokens, + log_probs, + eos_hyps_and_log_probs_scores, + memory, + scorer_memory, + attn, + prev_attn_peak, + scores, + ) = self.search_step( + alived_hyps, + inp_tokens, + log_probs, + eos_hyps_and_log_probs_scores, + memory, + scorer_memory, + attn, + prev_attn_peak, + enc_states, + enc_lens, + step, + ) + + if self._check_end_condition(alived_hyps): + break + + finals_hyps_and_log_probs_scores = ( + self._fill_alived_hyps_with_eos_token( + alived_hyps, eos_hyps_and_log_probs_scores, scores + ) + ) + + ( + topk_hyps, + topk_lengths, + topk_scores, + topk_log_probs, + ) = self._get_topk_prediction(finals_hyps_and_log_probs_scores) + + if self.return_topk: + return topk_hyps, topk_lengths, topk_scores, topk_log_probs + else: + # select the best hyps + best_hyps = topk_hyps[:, 0, :] + best_lens = topk_lengths[:, 0] + best_scores = topk_scores[:, 0] + best_log_probs = topk_log_probs[:, 0, :] + + # Convert best hypothesis to list + hyps = undo_padding(best_hyps, best_lens) + + return hyps, best_lens, best_scores, best_log_probs + + def _check_end_condition(self, alived_hyps): + """This method is supposed to be overridden by the child class. + For instance, if the decoder has a maximal number of tokens that it can + attend to, this method should return True when the maximal number of tokens + is reached. + """ + return False + + def permute_mem(self, memory, index): + """This method permutes the seq2seq model memory + to synchronize the memory index with the current output. + + Arguments + --------- + memory : No limit + The memory variable to be permuted. + index : torch.Tensor + The index of the previous path. + + Returns + ------- + The variable of the memory being permuted. + """ + raise NotImplementedError + return + + +class S2SRNNBeamSearcher(S2SBeamSearcher): + """ + This class implements the beam search decoding + for AttentionalRNNDecoder (speechbrain/nnet/RNN.py). + See also S2SBaseSearcher(), S2SBeamSearcher(). + + Arguments + --------- + embedding : torch.nn.Module + An embedding layer. + decoder : torch.nn.Module + Attentional RNN decoder. + linear : torch.nn.Module + A linear output layer. + temperature : float + Temperature factor applied to softmax. It changes the probability + distribution, being softer when T>1 and sharper with T<1. + **kwargs + see S2SBeamSearcher, arguments are directly passed. + + Example + ------- + >>> import speechbrain as sb + >>> vocab_size = 5 + >>> emb = torch.nn.Embedding(vocab_size, 3) + >>> dec = sb.nnet.RNN.AttentionalRNNDecoder( + ... "gru", "content", 3, 3, 1, enc_dim=7, input_size=3 + ... ) + >>> lin = sb.nnet.linear.Linear(n_neurons=vocab_size, input_size=3) + >>> coverage_scorer = sb.decoders.scorer.CoverageScorer(vocab_size) + >>> scorer = sb.decoders.scorer.ScorerBuilder( + ... full_scorers=[coverage_scorer], + ... partial_scorers=[], + ... weights=dict(coverage=1.5), + ... ) + >>> searcher = S2SRNNBeamSearcher( + ... embedding=emb, + ... decoder=dec, + ... linear=lin, + ... bos_index=4, + ... eos_index=4, + ... min_decode_ratio=0, + ... max_decode_ratio=1, + ... beam_size=2, + ... scorer=scorer, + ... ) + >>> batch_size = 2 + >>> enc = torch.rand([batch_size, 6, 7]) + >>> wav_len = torch.ones([batch_size]) + >>> hyps, _, _, _ = searcher(enc, wav_len) + """ + + def __init__(self, embedding, decoder, linear, temperature=1.0, **kwargs): + super().__init__(**kwargs) + self.emb = embedding + self.dec = decoder + self.fc = linear + self.softmax = torch.nn.LogSoftmax(dim=-1) + self.temperature = temperature + + def reset_mem(self, batch_size, device): + """Needed to reset the memory during beamsearch.""" + hs = None + self.dec.attn.reset() + c = torch.zeros(batch_size, self.dec.attn_dim, device=device) + return hs, c + + def forward_step(self, inp_tokens, memory, enc_states, enc_lens): + """Performs a step in the implemented beamsearcher.""" + with torch.no_grad(): + hs, c = memory + e = self.emb(inp_tokens) + dec_out, hs, c, w = self.dec.forward_step( + e, hs, c, enc_states, enc_lens + ) + log_probs = self.softmax(self.fc(dec_out) / self.temperature) + # average attn weight of heads when attn_type is multiheadlocation + if self.dec.attn_type == "multiheadlocation": + w = torch.mean(w, dim=1) + return log_probs, (hs, c), w + + def permute_mem(self, memory, index): + """Memory permutation during beamsearch.""" + hs, c = memory + + # shape of hs: [num_layers, batch_size, n_neurons] + if isinstance(hs, tuple): + hs_0 = torch.index_select(hs[0], dim=1, index=index) + hs_1 = torch.index_select(hs[1], dim=1, index=index) + hs = (hs_0, hs_1) + else: + hs = torch.index_select(hs, dim=1, index=index) + + c = torch.index_select(c, dim=0, index=index) + if self.dec.attn_type == "location": + self.dec.attn.prev_attn = torch.index_select( + self.dec.attn.prev_attn, dim=0, index=index + ) + return (hs, c) + + +class S2STransformerBeamSearcher(S2SBeamSearcher): + """This class implements the beam search decoding + for Transformer. + See also S2SBaseSearcher(), S2SBeamSearcher(). + + Arguments + --------- + modules : list with the following one: + model : torch.nn.Module + A Transformer model. + seq_lin : torch.nn.Module + A linear output layer. + temperature : float + Temperature factor applied to softmax. It changes the probability + distribution, being softer when T>1 and sharper with T<1. + **kwargs + Arguments to pass to S2SBeamSearcher + + Example + ------- + >>> from speechbrain.nnet.linear import Linear + >>> from speechbrain.lobes.models.transformer.TransformerASR import ( + ... TransformerASR, + ... ) + >>> from speechbrain.decoders import S2STransformerBeamSearcher + >>> batch_size = 8 + >>> n_channels = 6 + >>> input_size = 40 + >>> d_model = 128 + >>> tgt_vocab = 140 + >>> src = torch.rand([batch_size, n_channels, input_size]) + >>> tgt = torch.randint(0, tgt_vocab, [batch_size, n_channels]) + >>> net = TransformerASR( + ... tgt_vocab, + ... input_size, + ... d_model, + ... 8, + ... 1, + ... 1, + ... 1024, + ... activation=torch.nn.GELU, + ... ) + >>> ctc_lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) + >>> lin = Linear(input_shape=(1, 40, d_model), n_neurons=tgt_vocab) + >>> searcher = S2STransformerBeamSearcher( + ... modules=[net, lin], + ... bos_index=1, + ... eos_index=2, + ... min_decode_ratio=0.0, + ... max_decode_ratio=1.0, + ... using_eos_threshold=False, + ... beam_size=7, + ... temperature=1.15, + ... ) + >>> enc, dec = net.forward(src, tgt) + >>> hyps, _, _, _ = searcher(enc, torch.ones(batch_size)) + """ + + def __init__(self, modules, temperature=1.0, **kwargs): + super().__init__(**kwargs) + + self.model = modules[0] + self.fc = modules[1] + self.softmax = torch.nn.LogSoftmax(dim=-1) + + self.temperature = temperature + + def reset_mem(self, batch_size, device): + """Needed to reset the memory during beamsearch.""" + return None + + def permute_mem(self, memory, index): + """Memory permutation during beamsearch.""" + memory = torch.index_select(memory, dim=0, index=index) + return memory + + def forward_step(self, inp_tokens, memory, enc_states, enc_lens): + """Performs a step in the implemented beamsearcher.""" + memory = _update_mem(inp_tokens, memory) + pred, attn = self.model.decode(memory, enc_states, enc_lens) + prob_dist = self.softmax(self.fc(pred) / self.temperature) + return prob_dist[:, -1, :], memory, attn + + +class S2SWhisperBeamSearcher(S2SBeamSearcher): + """This class implements the beam search decoding + for Whisper neural nets made by OpenAI in + https://cdn.openai.com/papers/whisper.pdf. + + The beam search is stateful, meaning that some variables are stored + in the searcher. If you want to reuse the searcher in different + contexts, you should make sure that the variables are updated + accordingly. + + Arguments + --------- + module : list with the following one: + model : torch.nn.Module + A whisper model. It should have a decode() method. + temperature: float + The temperature to use during decoding. + use_kv_cache: bool (default: True) + Whether to use key-value cache. + suppress_blank: bool (default: True) + This will suppress blank outputs. + suppress_tokens: str or list (default: "-1") + list of tokens ids (or comma-separated token ids) to suppress + "-1" will suppress a set of symbols as defined in `model.non_speech_tokens()` + sample_len: int (default: None) + Maximum number of tokens to sample. + prefix: str or list (default: None) + Prefix to add to the input tokens. + See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + prompt: str or list (default: None) + Prompt to add to the input tokens. + See: https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + **kwargs + see S2SBeamSearcher, arguments are directly passed. + """ + + def __init__( + self, + module, + temperature=1.0, + use_kv_cache=True, + suppress_blank=True, + suppress_tokens="-1", + sample_len=None, + prefix=None, + prompt=None, + **kwargs, + ): + super().__init__( + bos_index=module[0].bos, + eos_index=module[0].eos, + **kwargs, + ) + + self.model = module[0] + self.temperature = temperature + self.use_kv_cache = use_kv_cache + self.kv_cache = None + self.suppress_blank = suppress_blank + self.suppress_tokens = suppress_tokens + + self.prefix = prefix + self.prompt = prompt + + self.max_attn_tokens = self.model.model.decoder.config.max_length + self.sample_len = sample_len or self.max_attn_tokens // 2 + + self.initial_tokens = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.eos_index: int = self.model.eos + self.bos_index: int = self.initial_tokens[-1] + + self.no_speech_probs = None + self.lang_tokens = None + + def set_lang_tokens(self, lang_tokens): + """Set the language to be used during decoding.""" + self.lang_tokens = lang_tokens + + def set_task(self, task): + """Set the task to be used during decoding.""" + self.model.set_task(task) + self.initial_tokens = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.bos_index: int = self.initial_tokens[-1] + + def set_prompt(self, prompt): + """Set the prompt to be used during decoding.""" + self.prompt = prompt + self.initial_tokens = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.bos_index: int = self.initial_tokens[-1] + + @cached_property + def get_tokens_to_suppress(self): + """Get the tokens to suppress during decoding if self.config.suppress_tokens is None.""" + suppress_tokens = self.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.model.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), ( + "suppress_tokens must be a list" + ) + + suppress_tokens.extend( + [ + self.model.transcribe, + self.model.translate, + self.model.bos, + self.model.bos_prev, + self.model.bos_lm, + ] + ) + + return tuple(sorted(set(suppress_tokens))) + + def _get_initial_tokens(self): + """Get the initial tokens to be used during decoding.""" + tokens = self.model.tokenizer.prefix_tokens + prefix = self.prefix + prompt = self.prompt + if prefix: + prefix_tokens = ( + self.model.tokenizer.encode( + " " + prefix.strip(), add_special_tokens=False + ) + if isinstance(prefix, str) + else prefix + ) + if self.sample_len is not None: + max_prefix_len = self.max_attn_tokens // 2 - self.sample_len + prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = tokens + prefix_tokens + + if prompt: + prompt_tokens = ( + self.model.tokenizer.encode( + " " + prompt.strip(), add_special_tokens=False + ) + if isinstance(prompt, str) + else prompt + ) + tokens = ( + [self.model.bos_prev] + + prompt_tokens[-(self.max_attn_tokens // 2 - 1) :] + + tokens + ) + return tuple(tokens) + + def reset_mem(self, batch_size, device): + """This method set the first tokens to be decoder_input_tokens during search.""" + # reset KV cache + if self.use_kv_cache: + self.kv_cache = None + + self.no_speech_probs = [torch.nan] * batch_size + + # the last token will be used as the first input token + # explaining why we are skipping it. + memory_tokens = self.initial_tokens[:-1] + mem = torch.tensor([memory_tokens] * batch_size).to(device) + if self.lang_tokens is not None: + mem[:, self.initial_tokens.index(self.model.bos) + 1] = ( + self.lang_tokens + ) + # after using it, reset it. + self.lang_token = None + return mem + + def permute_mem(self, memory, index): + """Permutes the memory.""" + memory = torch.index_select(memory, dim=0, index=index) + # if using kv_cache, we need to permute the kv_cache as well + if self.use_kv_cache: + self.kv_cache = self._reorder_cache(self.kv_cache, index) + return memory + + def _reorder_cache(self, past_key_values, beam_idx): + """Reorder the key-value cache. + + Arguments + --------- + past_key_values : tuple + The key-value cache. + beam_idx : torch.Tensor + The index of the previous path. + + Returns + ------- + The reordered key-value cache. + """ + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) + for past_state in layer_past + ), + ) + return reordered_past + + def set_n_out(self): + """set the number of output tokens.""" + return self.model.model.decoder.embed_tokens.weight.shape[0] + + def forward_step(self, inp_tokens, memory, enc_states, enc_lens): + """Performs a step in the implemented beamsearcher.""" + tokens = _update_mem(inp_tokens, memory) + + logits, attn, kv = self.model.forward_decoder( + enc_states, tokens, past_key_values=self.kv_cache + ) + + if tokens.shape[1] == self.sample_begin: + probs_at_bos = ( + logits[:, self.initial_tokens.index(self.model.bos)] + .float() + .softmax(dim=-1) + ) + self.no_speech_probs = probs_at_bos[ + :, self.model.no_speech + ].tolist() + + logits = logits[:, -1] + + if self.use_kv_cache: + self.kv_cache = kv + + if self.suppress_blank: + if tokens.shape[1] == self.sample_begin: + logits[ + :, + self.model.tokenizer.encode(" ", add_special_tokens=False) + + [self.eos_index], + ] = -torch.inf + + if self.suppress_tokens: + if self.model.config.suppress_tokens is None: + tokens_to_suppress = self.get_tokens_to_suppress + else: + tokens_to_suppress = self.model.get_suppress_tokens + logits[:, list(tokens_to_suppress)] = -torch.inf + + log_probs = ( + torch.nn.functional.log_softmax(logits.float(), dim=-1) + / self.temperature + ) + + return log_probs, tokens, attn + + def _check_end_condition(self, alived_hyps): + """This method checks if the max length is reached.""" + return ( + alived_hyps.alived_seq.shape[1] + >= self.max_attn_tokens - self.sample_begin + ) + + +class S2SHFTextBasedBeamSearcher(S2STransformerBeamSearcher): + """This class implements the beam search decoding + for the text-based HF seq2seq models, such as mBART or NLLB. + It is NOT significantly different from S2STransformerBeamSearcher. + This is why it inherits S2STransformerBeamSearcher. + The main difference might arise when one wishes to use directly + the lm_head of the text-based HF model rather than making a new + projection layer (self.fc = None). + + Arguments + --------- + modules : list with the following one: + model : torch.nn.Module + A Transformer model. + seq_lin : torch.nn.Module + A linear output layer. + Normally set to None for this usecase. + vocab_size : int + The dimension of the lm_head. + **kwargs + Arguments to pass to S2SBeamSearcher + """ + + def __init__(self, modules, vocab_size, **kwargs): + super().__init__(modules, **kwargs) + self.vocab_size = vocab_size + + def forward_step(self, inp_tokens, memory, enc_states, enc_lens): + """Performs a step in the implemented beamsearcher.""" + memory = _update_mem(inp_tokens, memory) + pred, attn = self.model.decode(memory, enc_states, enc_lens) + if self.fc is not None: + pred = self.fc(pred) + prob_dist = self.softmax(pred / self.temperature) + return prob_dist[:, -1, :], memory, attn + + def set_n_out(self): + """set the number of output tokens.""" + return self.vocab_size diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/transducer.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/transducer.py new file mode 100644 index 000000000..a4c8b3ffc --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/transducer.py @@ -0,0 +1,648 @@ +"""Decoders and output normalization for Transducer sequence. + +Author: + Abdelwahab HEBA 2020 + Sung-Lin Yeh 2020 +""" + +from dataclasses import dataclass +from functools import partial +from typing import Any, Optional + +import torch + + +@dataclass +class TransducerGreedySearcherStreamingContext(torch.nn.Module): + """Simple wrapper for the hidden state of the transducer greedy searcher. + Used by :meth:`~TransducerBeamSearcher.transducer_greedy_decode_streaming`. + """ + + hidden: Optional[Any] = None + """Hidden state; typically a tensor or a tuple of tensors.""" + + +class TransducerBeamSearcher(torch.nn.Module): + """ + This class implements the beam-search algorithm for the transducer model. + + Arguments + --------- + decode_network_lst : list + List of prediction network (PN) layers. + tjoint: transducer_joint module + This module perform the joint between TN and PN. + classifier_network : list + List of output layers (after performing joint between TN and PN) + exp: (TN,PN) => joint => classifier_network_list [DNN block, Linear..] => chars prob + blank_id : int + The blank symbol/index. + beam_size : int + The width of beam. Greedy Search is used when beam_size = 1. + nbest : int + Number of hypotheses to keep. + lm_module : torch.nn.ModuleList + Neural networks modules for LM. + lm_weight : float + The weight of LM when performing beam search (λ). + log P(y|x) + λ log P_LM(y). (default: 0.3) + state_beam : float + The threshold coefficient in log space to decide if hyps in A (process_hyps) + is likely to compete with hyps in B (beam_hyps), if not, end the while loop. + Reference: https://arxiv.org/pdf/1911.01629.pdf + expand_beam : float + The threshold coefficient to limit the number of expanded hypotheses + that are added in A (process_hyp). + Reference: https://arxiv.org/pdf/1911.01629.pdf + Reference: https://github.com/kaldi-asr/kaldi/blob/master/src/decoder/simple-decoder.cc (See PruneToks) + + Example + ------- + searcher = TransducerBeamSearcher( + decode_network_lst=[hparams["emb"], hparams["dec"]], + tjoint=hparams["Tjoint"], + classifier_network=[hparams["transducer_lin"]], + blank_id=0, + beam_size=hparams["beam_size"], + nbest=hparams["nbest"], + lm_module=hparams["lm_model"], + lm_weight=hparams["lm_weight"], + state_beam=2.3, + expand_beam=2.3, + ) + >>> from speechbrain.nnet.transducer.transducer_joint import ( + ... Transducer_joint, + ... ) + >>> import speechbrain as sb + >>> emb = sb.nnet.embedding.Embedding( + ... num_embeddings=35, + ... embedding_dim=3, + ... consider_as_one_hot=True, + ... blank_id=0, + ... ) + >>> dec = sb.nnet.RNN.GRU( + ... hidden_size=10, input_shape=(1, 40, 34), bidirectional=False + ... ) + >>> lin = sb.nnet.linear.Linear(input_shape=(1, 40, 10), n_neurons=35) + >>> joint_network = sb.nnet.linear.Linear( + ... input_shape=(1, 1, 40, 35), n_neurons=35 + ... ) + >>> tjoint = Transducer_joint(joint_network, joint="sum") + >>> searcher = TransducerBeamSearcher( + ... decode_network_lst=[emb, dec], + ... tjoint=tjoint, + ... classifier_network=[lin], + ... blank_id=0, + ... beam_size=1, + ... nbest=1, + ... lm_module=None, + ... lm_weight=0.0, + ... ) + >>> enc = torch.rand([1, 20, 10]) + >>> hyps, _, _, _ = searcher(enc) + """ + + def __init__( + self, + decode_network_lst, + tjoint, + classifier_network, + blank_id, + beam_size=4, + nbest=5, + lm_module=None, + lm_weight=0.0, + state_beam=2.3, + expand_beam=2.3, + ): + super().__init__() + self.decode_network_lst = decode_network_lst + self.tjoint = tjoint + self.classifier_network = classifier_network + self.blank_id = blank_id + self.beam_size = beam_size + self.nbest = nbest + self.lm = lm_module + self.lm_weight = lm_weight + + if lm_module is None and lm_weight > 0: + raise ValueError("Language model is not provided.") + + self.state_beam = state_beam + self.expand_beam = expand_beam + self.softmax = torch.nn.LogSoftmax(dim=-1) + + if self.beam_size <= 1: + self.searcher = self.transducer_greedy_decode + else: + self.searcher = self.transducer_beam_search_decode + + def forward(self, tn_output): + """ + Arguments + --------- + tn_output : torch.Tensor + Output from transcription network with shape + [batch, time_len, hiddens]. + + Returns + ------- + Topk hypotheses + """ + + hyps = self.searcher(tn_output) + return hyps + + def transducer_greedy_decode( + self, + tn_output, + hidden_state=None, + return_hidden=False, + max_symbols_per_step=5, + ): + """Transducer greedy decoder is a greedy decoder over batch which apply Transducer rules: + 1- for each time step in the Transcription Network (TN) output: + -> Update the ith utterance only if + the previous target != the new one (we save the hiddens and the target) + -> otherwise: + ---> keep the previous target prediction from the decoder + + Arguments + --------- + tn_output : torch.Tensor + Output from transcription network with shape + [batch, time_len, hiddens]. + hidden_state : (torch.Tensor, torch.Tensor) + Hidden state to initially feed the decode network with. This is + useful in conjunction with `return_hidden` to be able to perform + beam search in a streaming context, so that you can reuse the last + hidden state as an initial state across calls. + return_hidden : bool + Whether the return tuple should contain an extra 5th element with + the hidden state at of the last step. See `hidden_state`. + max_symbols_per_step : int + Maximum number of non-blank symbols to decode per time step. This is + useful to avoid infinite loops. + + Returns + ------- + Tuple of 4 or 5 elements (if `return_hidden`). + + First element: List[List[int]] + List of decoded tokens + + Second element: torch.Tensor + Outputs a logits tensor [B,T,1,Output_Dim]; padding + has not been removed. + + Third element: None + nbest; irrelevant for greedy decode + + Fourth element: None + nbest scores; irrelevant for greedy decode + + Fifth element: Present if `return_hidden`, (torch.Tensor, torch.Tensor) + Tuple representing the hidden state required to call + `transducer_greedy_decode` where you left off in a streaming + context. + """ + hyp = { + "prediction": [[] for _ in range(tn_output.size(0))], + "logp_scores": [0.0 for _ in range(tn_output.size(0))], + } + # prepare BOS = Blank for the Prediction Network (PN) + input_PN = ( + torch.ones( + (tn_output.size(0), 1), + device=tn_output.device, + dtype=torch.int32, + ) + * self.blank_id + ) + + if hidden_state is None: + # First forward-pass on PN + out_PN, hidden = self._forward_PN(input_PN, self.decode_network_lst) + else: + out_PN, hidden = hidden_state + + # For each time step + for t_step in range(tn_output.size(1)): + count = 0 + while count <= max_symbols_per_step: # avoid infinite loop + # do unsqueeze over since tjoint must be have a 4 dim [B,T,U,Hidden] + log_probs = self._joint_forward_step( + tn_output[:, t_step, :].unsqueeze(1).unsqueeze(1), + out_PN.unsqueeze(1), + ) + # Sort outputs at time + logp_targets, positions = torch.max( + log_probs.squeeze(1).squeeze(1), dim=1 + ) + # Batch hidden update + have_update_hyp = [] + for i in range(positions.size(0)): + # Update hiddens only if + # 1- current prediction is non blank + if positions[i].item() != self.blank_id: + hyp["prediction"][i].append(positions[i].item()) + hyp["logp_scores"][i] += logp_targets[i] + input_PN[i][0] = positions[i] + have_update_hyp.append(i) + if len(have_update_hyp) > 0: + # Select sentence to update + # And do a forward steps + generated hidden + ( + selected_input_PN, + selected_hidden, + ) = self._get_sentence_to_update( + have_update_hyp, input_PN, hidden + ) + selected_out_PN, selected_hidden = self._forward_PN( + selected_input_PN, + self.decode_network_lst, + selected_hidden, + ) + # update hiddens and out_PN + out_PN[have_update_hyp] = selected_out_PN + hidden = self._update_hiddens( + have_update_hyp, selected_hidden, hidden + ) + else: + break + count += 1 + + ret = ( + hyp["prediction"], + torch.Tensor(hyp["logp_scores"]).exp().mean(), + None, + None, + ) + + if return_hidden: + # append the `(out_PN, hidden)` tuple to ret + ret += ( + ( + out_PN, + hidden, + ), + ) + + return ret + + def transducer_greedy_decode_streaming( + self, x: torch.Tensor, context: TransducerGreedySearcherStreamingContext + ): + """Tiny wrapper for + :meth:`~TransducerBeamSearcher.transducer_greedy_decode` with an API + that makes it suitable to be passed as a `decoding_function` for + streaming. + + Arguments + --------- + x : torch.Tensor + Outputs of the prediction network (equivalent to `tn_output`) + context : TransducerGreedySearcherStreamingContext + Mutable streaming context object, which must be specified and reused + across calls when streaming. + You can obtain an initial context by initializing a default object. + + Returns + ------- + hyp : torch.Tensor + """ + (hyp, _scores, _, _, hidden) = self.transducer_greedy_decode( + x, context.hidden, return_hidden=True + ) + context.hidden = hidden + return hyp + + def transducer_beam_search_decode(self, tn_output): + """Transducer beam search decoder is a beam search decoder over batch which apply Transducer rules: + 1- for each utterance: + 2- for each time steps in the Transcription Network (TN) output: + -> Do forward on PN and Joint network + -> Select topK <= beam + -> Do a while loop extending the hyps until we reach blank + -> otherwise: + --> extend hyp by the new token + + Arguments + --------- + tn_output : torch.Tensor + Output from transcription network with shape + [batch, time_len, hiddens]. + + Returns + ------- + torch.Tensor + Outputs a logits tensor [B,T,1,Output_Dim]; padding + has not been removed. + """ + + # min between beam and max_target_lent + nbest_batch = [] + nbest_batch_score = [] + for i_batch in range(tn_output.size(0)): + # if we use RNN LM keep there hiddens + # prepare BOS = Blank for the Prediction Network (PN) + # Prepare Blank prediction + blank = ( + torch.ones((1, 1), device=tn_output.device, dtype=torch.int32) + * self.blank_id + ) + input_PN = ( + torch.ones((1, 1), device=tn_output.device, dtype=torch.int32) + * self.blank_id + ) + # First forward-pass on PN + hyp = { + "prediction": [self.blank_id], + "logp_score": 0.0, + "hidden_dec": None, + } + if self.lm_weight > 0: + lm_dict = {"hidden_lm": None} + hyp.update(lm_dict) + beam_hyps = [hyp] + + # For each time step + for t_step in range(tn_output.size(1)): + # get hyps for extension + process_hyps = beam_hyps + beam_hyps = [] + while True: + if len(beam_hyps) >= self.beam_size: + break + # Add norm score + a_best_hyp = max( + process_hyps, + key=partial(get_transducer_key), + ) + + # Break if best_hyp in A is worse by more than state_beam than best_hyp in B + if len(beam_hyps) > 0: + b_best_hyp = max( + beam_hyps, + key=partial(get_transducer_key), + ) + a_best_prob = a_best_hyp["logp_score"] + b_best_prob = b_best_hyp["logp_score"] + if b_best_prob >= self.state_beam + a_best_prob: + break + + # remove best hyp from process_hyps + process_hyps.remove(a_best_hyp) + + # forward PN + input_PN[0, 0] = a_best_hyp["prediction"][-1] + out_PN, hidden = self._forward_PN( + input_PN, + self.decode_network_lst, + a_best_hyp["hidden_dec"], + ) + # do unsqueeze over since tjoint must be have a 4 dim [B,T,U,Hidden] + log_probs = self._joint_forward_step( + tn_output[i_batch, t_step, :] + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + out_PN.unsqueeze(0), + ) + + if self.lm_weight > 0: + log_probs_lm, hidden_lm = self._lm_forward_step( + input_PN, a_best_hyp["hidden_lm"] + ) + + # Sort outputs at time + logp_targets, positions = torch.topk( + log_probs.view(-1), k=self.beam_size, dim=-1 + ) + best_logp = ( + logp_targets[0] + if positions[0] != blank + else logp_targets[1] + ) + + # Extend hyp by selection + for j in range(logp_targets.size(0)): + # hyp + topk_hyp = { + "prediction": a_best_hyp["prediction"][:], + "logp_score": a_best_hyp["logp_score"] + + logp_targets[j], + "hidden_dec": a_best_hyp["hidden_dec"], + } + + if positions[j] == self.blank_id: + beam_hyps.append(topk_hyp) + if self.lm_weight > 0: + topk_hyp["hidden_lm"] = a_best_hyp["hidden_lm"] + continue + + if logp_targets[j] >= best_logp - self.expand_beam: + topk_hyp["prediction"].append(positions[j].item()) + topk_hyp["hidden_dec"] = hidden + if self.lm_weight > 0: + topk_hyp["hidden_lm"] = hidden_lm + topk_hyp["logp_score"] += ( + self.lm_weight + * log_probs_lm[0, 0, positions[j]] + ) + process_hyps.append(topk_hyp) + # Add norm score + nbest_hyps = sorted( + beam_hyps, + key=partial(get_transducer_key), + reverse=True, + )[: self.nbest] + all_predictions = [] + all_scores = [] + for hyp in nbest_hyps: + all_predictions.append(hyp["prediction"][1:]) + all_scores.append(hyp["logp_score"] / len(hyp["prediction"])) + nbest_batch.append(all_predictions) + nbest_batch_score.append(all_scores) + return ( + [nbest_utt[0] for nbest_utt in nbest_batch], + torch.Tensor( + [nbest_utt_score[0] for nbest_utt_score in nbest_batch_score] + ) + .exp() + .mean(), + nbest_batch, + nbest_batch_score, + ) + + def _joint_forward_step(self, h_i, out_PN): + """Join predictions (TN & PN).""" + + with torch.no_grad(): + # the output would be a tensor of [B,T,U, oneof[sum,concat](Hidden_TN,Hidden_PN)] + out = self.tjoint( + h_i, + out_PN, + ) + # forward the output layers + activation + save logits + out = self._forward_after_joint(out, self.classifier_network) + log_probs = self.softmax(out) + return log_probs + + def _lm_forward_step(self, inp_tokens, memory): + """This method should implement one step of + forwarding operation for language model. + + Arguments + --------- + inp_tokens : torch.Tensor + The input tensor of the current timestep. + memory : No limit + The memory variables input for this timestep. + (e.g., RNN hidden states). + + Return + ------ + log_probs : torch.Tensor + Log-probabilities of the current timestep output. + hs : No limit + The memory variables are generated in this timestep. + (e.g., RNN hidden states). + """ + with torch.no_grad(): + logits, hs = self.lm(inp_tokens, hx=memory) + log_probs = self.softmax(logits) + return log_probs, hs + + def _get_sentence_to_update(self, selected_sentences, output_PN, hidden): + """Select and return the updated hiddens and output + from the Prediction Network. + + Arguments + --------- + selected_sentences : list + List of updated sentences (indexes). + output_PN: torch.Tensor + Output tensor from prediction network (PN). + hidden : torch.Tensor + Optional: None, hidden tensor to be used for + recurrent layers in the prediction network. + + Returns + ------- + selected_output_PN: torch.Tensor + Outputs a logits tensor [B_selected,U, hiddens]. + hidden_update_hyp: torch.Tensor + Selected hiddens tensor. + """ + + selected_output_PN = output_PN[selected_sentences, :] + # for LSTM hiddens (hn, hc) + if isinstance(hidden, tuple): + hidden0_update_hyp = hidden[0][:, selected_sentences, :] + hidden1_update_hyp = hidden[1][:, selected_sentences, :] + hidden_update_hyp = (hidden0_update_hyp, hidden1_update_hyp) + else: + hidden_update_hyp = hidden[:, selected_sentences, :] + return selected_output_PN, hidden_update_hyp + + def _update_hiddens(self, selected_sentences, updated_hidden, hidden): + """Update hidden tensor by a subset of hidden tensor (updated ones). + + Arguments + --------- + selected_sentences : list + List of index to be updated. + updated_hidden : torch.Tensor + Hidden tensor of the selected sentences for update. + hidden : torch.Tensor + Hidden tensor to be updated. + + Returns + ------- + torch.Tensor + Updated hidden tensor. + """ + + if isinstance(hidden, tuple): + hidden[0][:, selected_sentences, :] = updated_hidden[0] + hidden[1][:, selected_sentences, :] = updated_hidden[1] + else: + hidden[:, selected_sentences, :] = updated_hidden + return hidden + + def _forward_PN(self, out_PN, decode_network_lst, hidden=None): + """Compute forward-pass through a list of prediction network (PN) layers. + + Arguments + --------- + out_PN : torch.Tensor + Input sequence from prediction network with shape + [batch, target_seq_lens]. + decode_network_lst: list + List of prediction network (PN) layers. + hidden : torch.Tensor + Optional: None, hidden tensor to be used for + recurrent layers in the prediction network + + Returns + ------- + out_PN : torch.Tensor + Outputs a logits tensor [B,U, hiddens]. + hidden : torch.Tensor + Hidden tensor to be used for the next step + by recurrent layers in prediction network. + """ + + for layer in decode_network_lst: + if layer.__class__.__name__ in [ + "RNN", + "LSTM", + "GRU", + "LiGRU", + "LiGRU_Layer", + ]: + out_PN, hidden = layer(out_PN, hidden) + else: + out_PN = layer(out_PN) + return out_PN, hidden + + def _forward_after_joint(self, out, classifier_network): + """Compute forward-pass through a list of classifier neural network. + + Arguments + --------- + out : torch.Tensor + Output from joint network with shape + [batch, target_len, time_len, hiddens] + classifier_network : list + List of output layers (after performing joint between TN and PN) + exp: (TN,PN) => joint => classifier_network_list [DNN block, Linear..] => chars prob + + Returns + ------- + torch.Tensor + Outputs a logits tensor [B, U,T, Output_Dim]; + """ + + for layer in classifier_network: + out = layer(out) + return out + + +def get_transducer_key(x): + """Argument function to customize the sort order (in sorted & max). + To be used as `key=partial(get_transducer_key)`. + + Arguments + --------- + x : dict + one of the items under comparison + + Returns + ------- + float + Normalized log-score. + """ + logp_key = x["logp_score"] / len(x["prediction"]) + return logp_key diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/utils.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/utils.py new file mode 100644 index 000000000..fcdd1b20a --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/decoders/utils.py @@ -0,0 +1,158 @@ +"""Utils functions for the decoding modules. + +Authors + * Adel Moumen 2023 + * Ju-Chieh Chou 2020 + * Peter Plantinga 2020 + * Mirco Ravanelli 2020 + * Sung-Lin Yeh 2020 +""" + +import torch + + +def _update_mem(inp_tokens, memory): + """This function is for updating the memory for transformer searches. + it is called at each decoding step. When being called, it appends the + predicted token of the previous step to existing memory. + + Arguments + --------- + inp_tokens : torch.Tensor + Predicted token of the previous decoding step. + memory : torch.Tensor + Contains all the predicted tokens. + + Returns + ------- + Updated memory + """ + if memory is None: + memory = torch.empty(inp_tokens.size(0), 0, device=inp_tokens.device) + return torch.cat([memory, inp_tokens.unsqueeze(1)], dim=-1) + + +def inflate_tensor(tensor, times, dim): + """This function inflates the tensor for times along dim. + + Arguments + --------- + tensor : torch.Tensor + The tensor to be inflated. + times : int + The tensor will inflate for this number of times. + dim : int + The dim to be inflated. + + Returns + ------- + torch.Tensor + The inflated tensor. + + Example + ------- + >>> tensor = torch.Tensor([[1, 2, 3], [4, 5, 6]]) + >>> new_tensor = inflate_tensor(tensor, 2, dim=0) + >>> new_tensor + tensor([[1., 2., 3.], + [1., 2., 3.], + [4., 5., 6.], + [4., 5., 6.]]) + """ + return torch.repeat_interleave(tensor, times, dim=dim) + + +def mask_by_condition(tensor, cond, fill_value): + """This function will mask some element in the tensor with fill_value, if condition=False. + + Arguments + --------- + tensor : torch.Tensor + The tensor to be masked. + cond : torch.BoolTensor + This tensor has to be the same size as tensor. + Each element represents whether to keep the value in tensor. + fill_value : float + The value to fill in the masked element. + + Returns + ------- + torch.Tensor + The masked tensor. + + Example + ------- + >>> tensor = torch.Tensor([[1, 2, 3], [4, 5, 6]]) + >>> cond = torch.BoolTensor([[True, True, False], [True, False, False]]) + >>> mask_by_condition(tensor, cond, 0) + tensor([[1., 2., 0.], + [4., 0., 0.]]) + """ + return torch.where(cond, tensor, fill_value) + + +def batch_filter_seq2seq_output(prediction, eos_id=-1): + """Calling batch_size times of filter_seq2seq_output. + + Arguments + --------- + prediction : list of torch.Tensor + A list containing the output ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------- + list + The output predicted by seq2seq model. + + Example + ------- + >>> predictions = [ + ... torch.IntTensor([1, 2, 3, 4]), + ... torch.IntTensor([2, 3, 4, 5, 6]), + ... ] + >>> predictions = batch_filter_seq2seq_output(predictions, eos_id=4) + >>> predictions + [[1, 2, 3], [2, 3]] + """ + outputs = [] + for p in prediction: + res = filter_seq2seq_output(p.tolist(), eos_id=eos_id) + outputs.append(res) + return outputs + + +def filter_seq2seq_output(string_pred, eos_id=-1): + """Filter the output until the first eos occurs (exclusive). + + Arguments + --------- + string_pred : list + A list containing the output strings/ints predicted by the seq2seq system. + eos_id : int, string + The id of the eos. + + Returns + ------- + list + The output predicted by seq2seq model. + + Example + ------- + >>> string_pred = ["a", "b", "c", "d", "eos", "e"] + >>> string_out = filter_seq2seq_output(string_pred, eos_id="eos") + >>> string_out + ['a', 'b', 'c', 'd'] + """ + if isinstance(string_pred, list): + try: + eos_index = next( + i for i, v in enumerate(string_pred) if v == eos_id + ) + except StopIteration: + eos_index = len(string_pred) + string_out = string_pred[:eos_index] + else: + raise ValueError("The input must be a list.") + return string_out diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ASR.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ASR.py new file mode 100644 index 000000000..4029208e8 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ASR.py @@ -0,0 +1,1546 @@ +"""Specifies the inference interfaces for Automatic speech Recognition (ASR) modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023, 2024 + * Adel Moumen 2023, 2024, 2025 + * Pradnya Kandarkar 2023 +""" + +import functools +import itertools +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +import sentencepiece +import torch +import torchaudio +from tqdm import tqdm + +import speechbrain +from speechbrain.inference.interfaces import Pretrained +from speechbrain.utils.data_utils import split_path +from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig +from speechbrain.utils.fetching import fetch +from speechbrain.utils.streaming import split_fixed_chunks + + +class EncoderDecoderASR(Pretrained): + """A ready-to-use Encoder-Decoder ASR model + + The class can be used either to run only the encoder (encode()) to extract + features or to run the entire encoder-decoder model + (transcribe()) to transcribe speech. The given YAML must contain the fields + specified in the *_NEEDED[] lists. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.ASR import EncoderDecoderASR + >>> tmpdir = getfixture("tmpdir") + >>> asr_model = EncoderDecoderASR.from_hparams( + ... source="speechbrain/asr-crdnn-rnnlm-librispeech", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> asr_model.transcribe_file( + ... "tests/samples/single-mic/example2.flac" + ... ) # doctest: +SKIP + "MY FATHER HAS REVEALED THE CULPRIT'S NAME" + """ + + HPARAMS_NEEDED = ["tokenizer"] + MODULES_NEEDED = ["encoder", "decoder"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tokenizer = self.hparams.tokenizer + self.transducer_beam_search = False + self.transformer_beam_search = False + if hasattr(self.hparams, "transducer_beam_search"): + self.transducer_beam_search = self.hparams.transducer_beam_search + if hasattr(self.hparams, "transformer_beam_search"): + self.transformer_beam_search = self.hparams.transformer_beam_search + + def transcribe_file(self, path, **kwargs): + """Transcribes the given audiofile into a sequence of words. + + Arguments + --------- + path : str + Path to audio file which to transcribe. + **kwargs : dict + Arguments forwarded to ``load_audio``. + + Returns + ------- + str + The audiofile transcription produced by this ASR system. + """ + waveform = self.load_audio(path, **kwargs) + # Fake a batch: + batch = waveform.unsqueeze(0) + rel_length = torch.tensor([1.0]) + predicted_words, predicted_tokens = self.transcribe_batch( + batch, rel_length + ) + return predicted_words[0] + + def encode_batch(self, wavs, wav_lens): + """Encodes the input audio into a sequence of hidden states + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.Tensor + The encoded batch + """ + wavs = wavs.float() + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + encoder_out = self.mods.encoder(wavs, wav_lens) + if self.transformer_beam_search: + encoder_out = self.mods.transformer.encode(encoder_out, wav_lens) + return encoder_out + + def transcribe_batch(self, wavs, wav_lens): + """Transcribes the input audio into a sequence of words + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + list + Each waveform in the batch transcribed. + tensor + Each predicted token id. + """ + with torch.no_grad(): + wav_lens = wav_lens.to(self.device) + encoder_out = self.encode_batch(wavs, wav_lens) + if self.transducer_beam_search: + inputs = [encoder_out] + else: + inputs = [encoder_out, wav_lens] + predicted_tokens, _, _, _ = self.mods.decoder(*inputs) + predicted_words = [ + self.tokenizer.decode_ids(token_seq) + for token_seq in predicted_tokens + ] + return predicted_words, predicted_tokens + + def forward(self, wavs, wav_lens): + """Runs full transcription - note: no gradients through decoding""" + return self.transcribe_batch(wavs, wav_lens) + + +class EncoderASR(Pretrained): + """A ready-to-use Encoder ASR model + + The class can be used either to run only the encoder (encode()) to extract + features or to run the entire encoder + decoder function model + (transcribe()) to transcribe speech. The given YAML must contain the fields + specified in the *_NEEDED[] lists. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.ASR import EncoderASR + >>> tmpdir = getfixture("tmpdir") + >>> asr_model = EncoderASR.from_hparams( + ... source="speechbrain/asr-wav2vec2-commonvoice-fr", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> asr_model.transcribe_file( + ... "samples/audio_samples/example_fr.wav" + ... ) # doctest: +SKIP + """ + + HPARAMS_NEEDED = ["tokenizer", "decoding_function"] + MODULES_NEEDED = ["encoder"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.tokenizer = self.hparams.tokenizer + self.set_decoding_function() + + def set_decoding_function(self): + """Set the decoding function based on the parameters defined in the hyperparameter file. + + The decoding function is determined by the `decoding_function` specified in the hyperparameter file. + It can be either a functools.partial object representing a decoding function or an instance of + `speechbrain.decoders.ctc.CTCBaseSearcher` for beam search decoding. + + Raises: + ValueError: If the decoding function is neither a functools.partial nor an instance of + speechbrain.decoders.ctc.CTCBaseSearcher. + + Note: + - For greedy decoding (functools.partial), the provided `decoding_function` is assigned directly. + - For CTCBeamSearcher decoding, an instance of the specified `decoding_function` is created, and + additional parameters are added based on the tokenizer type. + """ + # Greedy Decoding case + if isinstance(self.hparams.decoding_function, functools.partial): + self.decoding_function = self.hparams.decoding_function + # CTCBeamSearcher case + else: + # 1. check if the decoding function is an instance of speechbrain.decoders.CTCBaseSearcher + if issubclass( + self.hparams.decoding_function, + speechbrain.decoders.ctc.CTCBaseSearcher, + ): + # If so, we need to retrieve the vocab list from the tokenizer. + # We also need to check if the tokenizer is a sentencepiece or a CTCTextEncoder. + if isinstance( + self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder + ): + ind2lab = self.tokenizer.ind2lab + vocab_list = [ind2lab[x] for x in range(len(ind2lab))] + elif isinstance( + self.tokenizer, sentencepiece.SentencePieceProcessor + ): + vocab_list = [ + self.tokenizer.id_to_piece(i) + for i in range(self.tokenizer.vocab_size()) + ] + else: + raise ValueError( + "The tokenizer must be sentencepiece or CTCTextEncoder" + ) + + # We can now instantiate the decoding class and add all the parameters + if hasattr(self.hparams, "test_beam_search"): + opt_beam_search_params = self.hparams.test_beam_search + # check if the kenlm_model_path is provided and fetch it if necessary + if "kenlm_model_path" in opt_beam_search_params: + source, fl = split_path( + opt_beam_search_params["kenlm_model_path"] + ) + kenlm_model_path = str( + fetch( + fl, source=source, savedir=self.hparams.savedir + ) + ) + # we need to update the kenlm_model_path in the opt_beam_search_params + opt_beam_search_params["kenlm_model_path"] = ( + kenlm_model_path + ) + else: + opt_beam_search_params = {} + self.decoding_function = self.hparams.decoding_function( + **opt_beam_search_params, vocab_list=vocab_list + ) + else: + raise ValueError( + "The decoding function must be an instance of speechbrain.decoders.CTCBaseSearcher" + ) + + def transcribe_file(self, path, **kwargs): + """Transcribes the given audiofile into a sequence of words. + + Arguments + --------- + path : str + Path to audio file which to transcribe. + **kwargs : dict + Arguments forwarded to ``load_audio``. + + Returns + ------- + str + The audiofile transcription produced by this ASR system. + """ + waveform = self.load_audio(path, **kwargs) + # Fake a batch: + batch = waveform.unsqueeze(0) + rel_length = torch.tensor([1.0]) + predicted_words, predicted_tokens = self.transcribe_batch( + batch, rel_length + ) + return str(predicted_words[0]) + + def encode_batch(self, wavs, wav_lens): + """Encodes the input audio into a sequence of hidden states + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.Tensor + The encoded batch + """ + wavs = wavs.float() + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + encoder_out = self.mods.encoder(wavs, wav_lens) + return encoder_out + + def transcribe_batch(self, wavs, wav_lens): + """Transcribes the input audio into a sequence of words + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + list + Each waveform in the batch transcribed. + tensor + Each predicted token id. + """ + with torch.no_grad(): + wav_lens = wav_lens.to(self.device) + encoder_out = self.encode_batch(wavs, wav_lens) + predictions = self.decoding_function(encoder_out, wav_lens) + is_ctc_text_encoder_tokenizer = isinstance( + self.tokenizer, speechbrain.dataio.encoder.CTCTextEncoder + ) + if isinstance(self.hparams.decoding_function, functools.partial): + if is_ctc_text_encoder_tokenizer: + predicted_words = [ + "".join(self.tokenizer.decode_ndim(token_seq)) + for token_seq in predictions + ] + else: + predicted_words = [ + self.tokenizer.decode_ids(token_seq) + for token_seq in predictions + ] + else: + predicted_words = [hyp[0].text for hyp in predictions] + + return predicted_words, predictions + + def forward(self, wavs, wav_lens): + """Runs the encoder""" + return self.encode_batch(wavs, wav_lens) + + +@dataclass +class ASRWhisperSegment: + """A single chunk of audio for Whisper ASR streaming. + + This object is intended to be mutated as streaming progresses and passed across calls + to the lower-level APIs such as `encode_chunk`, `decode_chunk`, etc. + + Attributes + ---------- + start : float + The start time of the audio chunk. + end : float + The end time of the audio chunk. + chunk : torch.Tensor + The audio chunk, shape [time, channels]. + lang_id : str + The language identifier associated with the audio chunk. + words : str + The predicted words for the audio chunk. + tokens : List[int] + The predicted tokens for the audio chunk. + prompt : List[str] + The prompt associated with the audio chunk. + avg_log_probs : float + The average log probability associated with the prediction. + no_speech_prob : float + The probability of no speech in the audio chunk. + """ + + start: float + end: float + chunk: torch.Tensor + lang_id: Optional[str] = None + words: Optional[str] = None + tokens: Optional[List[str]] = None + prompt: Optional[List[str]] = None + avg_log_probs: Optional[float] = None + no_speech_prob: Optional[float] = None + + +class WhisperASR(Pretrained): + """A ready-to-use Whisper ASR model. + + The class can be used to run the entire encoder-decoder whisper model. + The set of tasks supported are: ``transcribe``, ``translate``, and ``lang_id``. + The given YAML must contains the fields specified in the *_NEEDED[] lists. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.ASR import WhisperASR + >>> tmpdir = getfixture("tmpdir") + >>> asr_model = WhisperASR.from_hparams( + ... source="speechbrain/asr-whisper-medium-commonvoice-it", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> hyp = asr_model.transcribe_file( + ... "speechbrain/asr-whisper-medium-commonvoice-it/example-it.wav" + ... ) # doctest: +SKIP + >>> hyp # doctest: +SKIP + buongiorno a tutti e benvenuti a bordo + >>> _, probs = asr_model.detect_language_file( + ... "speechbrain/asr-whisper-medium-commonvoice-it/example-it.wav" + ... ) # doctest: +SKIP + >>> print( + ... f"Detected language: {max(probs[0], key=probs[0].get)}" + ... ) # doctest: +SKIP + Detected language: it + """ + + HPARAMS_NEEDED = ["language", "sample_rate"] + MODULES_NEEDED = ["whisper", "decoder"] + TASKS = ["transcribe", "translate", "lang_id"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tokenizer = self.hparams.whisper.tokenizer + + @torch.no_grad() + def detect_language_file(self, path: str): + """Detects the language of the given audiofile. + This method only works on input_file of 30 seconds or less. + + Arguments + --------- + path : str + Path to audio file which to transcribe. + + Returns + ------- + language_tokens : torch.Tensor + The detected language tokens. + language_probs : dict + The probabilities of the detected language tokens. + + Raises + ------ + ValueError + If the model doesn't have language tokens. + """ + wavs = self.load_audio(path).float().to(self.device).unsqueeze(0) + mel = self.mods.whisper._get_mel(wavs) + language_tokens, language_probs = self.mods.whisper.detect_language(mel) + return language_tokens, language_probs + + @torch.no_grad() + def detect_language_batch(self, wav: torch.Tensor): + """Detects the language of the given wav Tensor. + This method only works on wav files of 30 seconds or less. + + Arguments + --------- + wav : torch.tensor + Batch of waveforms [batch, time, channels]. + + Returns + ------- + language_tokens : torch.Tensor of shape (batch_size,) + ids of the most probable language tokens, which appears after the startoftranscript token. + language_probs : List[Dict[str, float]] + list of dictionaries containing the probability distribution over all languages. + + Raises + ------ + ValueError + If the model doesn't have language tokens. + + Example + ------- + >>> from speechbrain.inference.ASR import WhisperASR + >>> from speechbrain.dataio import audio_io + >>> tmpdir = getfixture("tmpdir") + >>> asr_model = WhisperASR.from_hparams( + ... source="speechbrain/asr-whisper-medium-commonvoice-it", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> wav, _ = audio_io.load("your_audio") # doctest: +SKIP + >>> language_tokens, language_probs = asr_model.detect_language( + ... wav + ... ) # doctest: +SKIP + """ + mel = self.mods.whisper._get_mel(wav) + language_tokens, language_probs = self.mods.whisper.detect_language(mel) + return language_tokens, language_probs + + @torch.no_grad() + def _detect_language(self, mel: torch.Tensor, task: str): + """Detects the language of the given mel spectrogram. + + Arguments + --------- + mel : torch.tensor + Batch of mel spectrograms [batch, time, channels]. + task : str + The task to perform. + + Returns + ------- + language_tokens : Tensor, shape = (n_audio,) + ids of the most probable language tokens, which appears after the startoftranscript token. + language_probs : List[Dict[str, float]], length = n_audio + list of dictionaries containing the probability distribution over all languages. + """ + languages = [self.mods.whisper.language] * mel.shape[0] + lang_probs = None + + if self.mods.whisper.language is None or task == "lang_id": + lang_tokens, lang_probs = self.mods.whisper.detect_language(mel) + languages = [max(probs, key=probs.get) for probs in lang_probs] + self.mods.decoder.set_lang_tokens(lang_tokens) + return languages, lang_probs + + def _get_audio_stream( + self, streamer: "torchaudio.io.StreamReader", frames_per_chunk: int + ): + """From a :class:`torchaudio.io.StreamReader`, identifies the audio + stream and returns an iterable stream of chunks (after resampling and + downmixing to mono). + + Arguments + --------- + streamer : torchaudio.io.StreamReader + The stream object. Must hold exactly one source stream of an + audio type. + frames_per_chunk : int + The number of frames per chunk. For a streaming model, this should + be determined from the DynChunkTrain configuration. + + Yields + ------ + chunks from streamer + """ + + stream_infos = [ + streamer.get_src_stream_info(i) + for i in range(streamer.num_src_streams) + ] + + audio_stream_infos = [ + (i, stream_info) + for i, stream_info in enumerate(stream_infos) + if stream_info.media_type == "audio" + ] + + if len(audio_stream_infos) != 1: + raise ValueError( + f"Expected stream to have only 1 stream (with any number of channels), got {len(audio_stream_infos)} (with streams: {stream_infos})" + ) + + # find the index of the first (and only) audio stream + audio_stream_index = audio_stream_infos[0][0] + + # output stream #0 + streamer.add_basic_audio_stream( + frames_per_chunk=frames_per_chunk, + stream_index=audio_stream_index, + sample_rate=self.audio_normalizer.sample_rate, + format="fltp", # torch.float32 + num_channels=1, + buffer_chunk_size=-1, # avoiding the problem of dropping first chunks + ) + + for (chunk,) in streamer.stream(): + chunk = chunk.squeeze(-1) # we deal with mono, remove that dim + chunk = chunk.unsqueeze(0) # create a fake batch dim + yield chunk + + @torch.no_grad() + def transcribe_file_streaming( + self, + path: str, + task: Optional[str] = None, + initial_prompt: Optional[str] = None, + logprob_threshold: Optional[float] = -1.0, + no_speech_threshold=0.6, + condition_on_previous_text: bool = False, + verbose: bool = False, + use_torchaudio_streaming: bool = False, + chunk_size: int = 30, + **kwargs, + ): + """Transcribes the given audiofile into a sequence of words. + This method supports the following tasks: ``transcribe``, ``translate``, and ``lang_id``. + It can process an input audio file longer than 30 seconds by splitting it into chunk_size-second segments. + + Arguments + --------- + path : str + URI/path to the audio to transcribe. When + ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow + fetching from HF or a local file. When ``True``, resolves the URI + through ffmpeg, as documented in + :class:`torchaudio.io.StreamReader`. + task : Optional[str] + The task to perform. If None, the default task is the one passed in the Whisper model. + initial_prompt : Optional[str] + The initial prompt to condition the model on. + logprob_threshold : Optional[float] + The log probability threshold to continue decoding the current segment. + no_speech_threshold : float + The threshold to skip decoding segment if the no_speech_prob is higher than this value. + condition_on_previous_text : bool + If True, the model will be condition on the last 224 tokens. + verbose : bool + If True, print the transcription of each segment. + use_torchaudio_streaming : bool + Whether the audio file can be loaded in a streaming fashion. If not, + transcription is still performed through chunks of audio, but the + entire audio file is fetched and loaded at once. + This skips the usual fetching method and instead resolves the URI + using torchaudio (via ffmpeg). + chunk_size : int + The size of the chunks to split the audio into. The default + chunk size is 30 seconds which corresponds to the maximal length + that the model can process in one go. + **kwargs : dict + Arguments forwarded to ``load_audio`` + + Yields + ------ + ASRWhisperSegment + A new ASRWhisperSegment instance initialized with the provided parameters. + """ + if task is not None: + if task in self.TASKS: + if task != "lang_id": + self.mods.decoder.set_task(task) + else: + raise ValueError( + f"Task {task} not supported. Supported tasks are {self.TASKS}" + ) + + # create chunks of chunk_size seconds + num_frames_per_chunk = chunk_size * self.hparams.sample_rate + if use_torchaudio_streaming: + streamer = torchaudio.io.StreamReader(path) + segments = self._get_audio_stream(streamer, num_frames_per_chunk) + else: + waveform = self.load_audio(path, **kwargs) + batch = waveform.unsqueeze(0) + segments = split_fixed_chunks(batch, num_frames_per_chunk) + + rel_length = torch.tensor([1.0]) + + all_tokens = [] + prompt_reset_since = 0 + if initial_prompt is not None: + initial_prompt_tokens = self.whisper.tokenizer.encode( + " " + initial_prompt.strip() + ) + all_tokens.extend(initial_prompt_tokens) + else: + initial_prompt_tokens = [] + + for i, segment in enumerate(tqdm(segments, disable=verbose)): + # move the segment on the device + segment = segment.to(self.device) + + # extract mel spectrogram + mel_segment = self.mods.whisper._get_mel(segment) + + start = i * chunk_size + end = (i + 1) * chunk_size + + encoder_out = self.mods.whisper.forward_encoder(mel_segment) + languages, _ = self._detect_language(mel_segment, task) + + if task == "lang_id": + yield ASRWhisperSegment( + start=start, + end=end, + chunk=segment, + lang_id=languages[0], + ) + continue + + prompt = all_tokens[prompt_reset_since:] + self.mods.decoder.set_prompt(prompt) + + predicted_tokens, _, scores, _ = self.mods.decoder( + encoder_out, rel_length + ) + avg_log_probs = scores.sum() / (len(predicted_tokens[0]) + 1) + + if no_speech_threshold is not None: + should_skip = ( + self.mods.decoder.no_speech_probs[0] > no_speech_threshold + ) + if ( + logprob_threshold is not None + and avg_log_probs > logprob_threshold + ): + # don't skip if the logprob is high enough, despite the no_speech_prob + should_skip = False + + if should_skip: + yield ASRWhisperSegment( + start=start, + end=end, + chunk=segment, + lang_id=languages[0], + words="", + tokens=[], + prompt=prompt, + avg_log_probs=avg_log_probs.item(), + no_speech_prob=self.mods.decoder.no_speech_probs[0], + ) + continue + + predicted_words = [ + self.tokenizer.decode(t, skip_special_tokens=True).strip() + for t in predicted_tokens + ] + + yield ASRWhisperSegment( + start=start, + end=end, + chunk=segment, + lang_id=languages[0], + words=predicted_words[0], + tokens=predicted_tokens[0], + prompt=prompt, + avg_log_probs=avg_log_probs.item(), + no_speech_prob=self.mods.decoder.no_speech_probs[0], + ) + + all_tokens.extend(predicted_tokens[0]) + + if ( + not condition_on_previous_text + or self.mods.decoder.temperature > 0.5 + ): + prompt_reset_since = len(all_tokens) + + def transcribe_file( + self, + path: str, + task: Optional[str] = None, + initial_prompt: Optional[str] = None, + logprob_threshold: Optional[float] = -1.0, + no_speech_threshold=0.6, + condition_on_previous_text: bool = False, + verbose: bool = False, + use_torchaudio_streaming: bool = False, + chunk_size: Optional[int] = 30, + **kwargs, + ) -> List[ASRWhisperSegment]: + """Run the Whisper model using the specified task on the given audio file and return the ``ASRWhisperSegment`` objects + for each segment. + + This method supports the following tasks: ``transcribe``, ``translate``, and ``lang_id``. + It can process an input audio file longer than 30 seconds by splitting it into chunk_size-second segments. + + Arguments + --------- + path : str + URI/path to the audio to transcribe. When + ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow + fetching from HF or a local file. When ``True``, resolves the URI + through ffmpeg, as documented in + :class:`torchaudio.io.StreamReader`. + task : Optional[str] + The task to perform. If None, the default task is the one passed in the Whisper model. + It can be one of the following: ``transcribe``, ``translate``, ``lang_id``. + initial_prompt : Optional[str] + The initial prompt to condition the model on. + logprob_threshold : Optional[float] + The log probability threshold to continue decoding the current segment. + no_speech_threshold : float + The threshold to skip decoding segment if the no_speech_prob is higher than this value. + condition_on_previous_text : bool + If True, the model will be condition on the last 224 tokens. + verbose : bool + If True, print the details of each segment. + use_torchaudio_streaming : bool + Whether the audio file can be loaded in a streaming fashion. If not, + transcription is still performed through chunks of audio, but the + entire audio file is fetched and loaded at once. + This skips the usual fetching method and instead resolves the URI + using torchaudio (via ffmpeg). + chunk_size : Optional[int] + The size of the chunks to split the audio into. The default + chunk size is 30 seconds which corresponds to the maximal length + that the model can process in one go. + **kwargs : dict + Arguments forwarded to ``load_audio`` + + Returns + ------- + results : list + A list of ``WhisperASRChunk`` objects, each containing the task result. + """ + results = [] + for whisper_segment in self.transcribe_file_streaming( + path, + task=task, + initial_prompt=initial_prompt, + logprob_threshold=logprob_threshold, + no_speech_threshold=no_speech_threshold, + condition_on_previous_text=condition_on_previous_text, + verbose=verbose, + use_torchaudio_streaming=use_torchaudio_streaming, + chunk_size=chunk_size, + **kwargs, + ): + results.append(whisper_segment) + if verbose: + pred = ( + whisper_segment.words + if task != "lang_id" + else whisper_segment.lang_id + ) + print( + f"[{whisper_segment.start}s --> {whisper_segment.end}s] {pred}" + ) + return results + + def encode_batch(self, wavs, wav_lens): + """Encodes the input audio into a sequence of hidden states + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.tensor + Batch of waveforms [batch, time, channels]. + wav_lens : torch.tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.tensor + The encoded batch + """ + wavs = wavs.to(device=self.device, dtype=torch.float32) + mel = self.mods.whisper._get_mel(wavs) + encoder_out = self.mods.whisper.forward_encoder(mel) + return encoder_out + + @torch.no_grad() + def transcribe_batch(self, wavs, wav_lens): + """Transcribes the input audio into a sequence of words + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.tensor + Batch of waveforms [batch, time, channels]. + wav_lens : torch.tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + list + Each waveform in the batch transcribed. + tensor + Each predicted token id. + """ + wav_lens = wav_lens.float().to(self.device) + encoder_out = self.encode_batch(wavs, wav_lens) + predicted_tokens, _, _, _ = self.mods.decoder(encoder_out, wav_lens) + predicted_words = [ + self.tokenizer.decode(t, skip_special_tokens=True).strip() + for t in predicted_tokens + ] + if self.hparams.normalized_transcripts: + predicted_words = [ + self.tokenizer.normalize(text).split(" ") + for text in predicted_words + ] + + return predicted_words, predicted_tokens + + def forward(self, wavs, wav_lens): + """Runs full transcription - note: no gradients through decoding""" + return self.transcribe_batch(wavs, wav_lens) + + +@dataclass +class ASRStreamingContext: + """Streaming metadata, initialized by + :meth:`~StreamingASR.make_streaming_context` (see there for details on + initialization of fields here). + + This object is intended to be mutate: the same object should be passed + across calls as streaming progresses (namely when using the lower-level + :meth:`~StreamingASR.encode_chunk`, etc. APIs). + + Holds some references to opaque streaming contexts, so the context is + model-agnostic to an extent.""" + + config: DynChunkTrainConfig + """Dynamic chunk training configuration used to initialize the streaming + context. Cannot be modified on the fly.""" + + fea_extractor_context: Any + """Opaque feature extractor streaming context.""" + + encoder_context: Any + """Opaque encoder streaming context.""" + + decoder_context: Any + """Opaque decoder streaming context.""" + + tokenizer_context: Optional[List[Any]] + """Opaque streaming context for the tokenizer. Initially `None`. Initialized + to a list of tokenizer contexts once batch size can be determined.""" + + +class StreamingASR(Pretrained): + """A ready-to-use, streaming-capable ASR model. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.ASR import StreamingASR + >>> from speechbrain.utils.dynamic_chunk_training import DynChunkTrainConfig + >>> tmpdir = getfixture("tmpdir") + >>> asr_model = StreamingASR.from_hparams( + ... source="speechbrain/asr-conformer-streaming-librispeech", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> asr_model.transcribe_file( + ... "speechbrain/asr-conformer-streaming-librispeech/test-en.wav", + ... DynChunkTrainConfig(24, 8), + ... ) # doctest: +SKIP + """ + + HPARAMS_NEEDED = [ + "fea_streaming_extractor", + "make_decoder_streaming_context", + "decoding_function", + "make_tokenizer_streaming_context", + "tokenizer_decode_streaming", + ] + MODULES_NEEDED = ["enc", "proj_enc"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.filter_props = self.hparams.fea_streaming_extractor.properties + + def _get_audio_stream( + self, streamer: "torchaudio.io.StreamReader", frames_per_chunk: int + ): + """From a :class:`torchaudio.io.StreamReader`, identifies the audio + stream and returns an iterable stream of chunks (after resampling and + downmixing to mono). + + Arguments + --------- + streamer : torchaudio.io.StreamReader + The stream object. Must hold exactly one source stream of an + audio type. + frames_per_chunk : int + The number of frames per chunk. For a streaming model, this should + be determined from the DynChunkTrain configuration. + + Yields + ------ + chunks from streamer + """ + + stream_infos = [ + streamer.get_src_stream_info(i) + for i in range(streamer.num_src_streams) + ] + + audio_stream_infos = [ + (i, stream_info) + for i, stream_info in enumerate(stream_infos) + if stream_info.media_type == "audio" + ] + + if len(audio_stream_infos) != 1: + raise ValueError( + f"Expected stream to have only 1 stream (with any number of channels), got {len(audio_stream_infos)} (with streams: {stream_infos})" + ) + + # find the index of the first (and only) audio stream + audio_stream_index = audio_stream_infos[0][0] + + # output stream #0 + streamer.add_basic_audio_stream( + frames_per_chunk=frames_per_chunk, + stream_index=audio_stream_index, + sample_rate=self.audio_normalizer.sample_rate, + format="fltp", # torch.float32 + num_channels=1, + ) + + for (chunk,) in streamer.stream(): + chunk = chunk.squeeze(-1) # we deal with mono, remove that dim + chunk = chunk.unsqueeze(0) # create a fake batch dim + yield chunk + + def transcribe_file_streaming( + self, + path, + dynchunktrain_config: DynChunkTrainConfig, + use_torchaudio_streaming: bool = True, + **kwargs, + ): + """Transcribes the given audio file into a sequence of words, in a + streaming fashion, meaning that text is being yield from this + generator, in the form of strings to concatenate. + + Arguments + --------- + path : str + URI/path to the audio to transcribe. When + ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow + fetching from HF or a local file. When ``True``, resolves the URI + through ffmpeg, as documented in + :class:`torchaudio.io.StreamReader`. + dynchunktrain_config : DynChunkTrainConfig + Streaming configuration. Sane values and how much time chunks + actually represent is model-dependent. + use_torchaudio_streaming : bool + Whether the audio file can be loaded in a streaming fashion. If not, + transcription is still performed through chunks of audio, but the + entire audio file is fetched and loaded at once. + This skips the usual fetching method and instead resolves the URI + using torchaudio (via ffmpeg). + **kwargs : dict + Arguments forwarded to ``load_audio`` + + Yields + ------ + generator of str + An iterator yielding transcribed chunks (strings). There is a yield + for every chunk, even if the transcribed string for that chunk is an + empty string. + """ + + chunk_size = self.get_chunk_size_frames(dynchunktrain_config) + + if use_torchaudio_streaming: + streamer = torchaudio.io.StreamReader(path) + chunks = self._get_audio_stream(streamer, chunk_size) + else: + waveform = self.load_audio(path, **kwargs) + batch = waveform.unsqueeze(0) # create batch dim + chunks = split_fixed_chunks(batch, chunk_size) + + rel_length = torch.tensor([1.0]) + context = self.make_streaming_context(dynchunktrain_config) + + final_chunks = ( + [torch.zeros((1, chunk_size), device=self.device)] + * self.hparams.fea_streaming_extractor.get_recommended_final_chunk_count( + chunk_size + ) + ) + + for chunk in itertools.chain(chunks, final_chunks): + predicted_words = self.transcribe_chunk(context, chunk, rel_length) + yield predicted_words[0] + + def transcribe_file( + self, + path, + dynchunktrain_config: DynChunkTrainConfig, + use_torchaudio_streaming: bool = True, + ): + """Transcribes the given audio file into a sequence of words. + + Arguments + --------- + path : str + URI/path to the audio to transcribe. When + ``use_torchaudio_streaming`` is ``False``, uses SB fetching to allow + fetching from HF or a local file. When ``True``, resolves the URI + through ffmpeg, as documented in + :class:`torchaudio.io.StreamReader`. + dynchunktrain_config : DynChunkTrainConfig + Streaming configuration. Sane values and how much time chunks + actually represent is model-dependent. + use_torchaudio_streaming : bool + Whether the audio file can be loaded in a streaming fashion. If not, + transcription is still performed through chunks of audio, but the + entire audio file is fetched and loaded at once. + This skips the usual fetching method and instead resolves the URI + using torchaudio (via ffmpeg). + + Returns + ------- + str + The audio file transcription produced by this ASR system. + """ + + pred = "" + + for text_chunk in self.transcribe_file_streaming( + path, dynchunktrain_config, use_torchaudio_streaming + ): + pred += text_chunk + + return pred + + def make_streaming_context(self, dynchunktrain_config: DynChunkTrainConfig): + """Create a blank streaming context to be passed around for chunk + encoding/transcription. + + Arguments + --------- + dynchunktrain_config : DynChunkTrainConfig + Streaming configuration. Sane values and how much time chunks + actually represent is model-dependent. + + Returns + ------- + ASRStreamingContext + """ + + return ASRStreamingContext( + config=dynchunktrain_config, + fea_extractor_context=self.hparams.fea_streaming_extractor.make_streaming_context(), + encoder_context=self.mods.enc.make_streaming_context( + dynchunktrain_config + ), + decoder_context=self.hparams.make_decoder_streaming_context(), + tokenizer_context=None, + ) + + def get_chunk_size_frames( + self, dynchunktrain_config: DynChunkTrainConfig + ) -> int: + """Returns the chunk size in actual audio samples, i.e. the exact + expected length along the time dimension of an input chunk tensor (as + passed to :meth:`~StreamingASR.encode_chunk` and similar low-level + streaming functions). + + Arguments + --------- + dynchunktrain_config : DynChunkTrainConfig + The streaming configuration to determine the chunk frame count of. + + Returns + ------- + chunk size + """ + + return (self.filter_props.stride - 1) * dynchunktrain_config.chunk_size + + @torch.no_grad() + def encode_chunk( + self, + context: ASRStreamingContext, + chunk: torch.Tensor, + chunk_len: Optional[torch.Tensor] = None, + ): + """Encoding of a batch of audio chunks into a batch of encoded + sequences. + For full speech-to-text offline transcription, use `transcribe_batch` or + `transcribe_file`. + Must be called over a given context in the correct order of chunks over + time. + + Arguments + --------- + context : ASRStreamingContext + Mutable streaming context object, which must be specified and reused + across calls when streaming. + You can obtain an initial context by calling + `asr.make_streaming_context(config)`. + + chunk : torch.Tensor + The tensor for an audio chunk of shape `[batch size, time]`. + The time dimension must strictly match + `asr.get_chunk_size_frames(config)`. + The waveform is expected to be in the model's expected format (i.e. + the sampling rate must be correct). + + chunk_len : torch.Tensor, optional + The relative chunk length tensor of shape `[batch size]`. This is to + be used when the audio in one of the chunks of the batch is ending + within this chunk. + If unspecified, equivalent to `torch.ones((batch_size,))`. + + Returns + ------- + torch.Tensor + Encoded output, of a model-dependent shape.""" + + if chunk_len is None: + chunk_len = torch.ones((chunk.size(0),)) + + chunk = chunk.float() + chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device) + + assert chunk.shape[-1] <= self.get_chunk_size_frames(context.config) + + x = self.hparams.fea_streaming_extractor( + chunk, context=context.fea_extractor_context, lengths=chunk_len + ) + x = self.mods.enc.forward_streaming(x, context.encoder_context) + x = self.mods.proj_enc(x) + return x + + @torch.no_grad() + def decode_chunk( + self, context: ASRStreamingContext, x: torch.Tensor + ) -> Tuple[List[str], List[List[int]]]: + """Decodes the output of the encoder into tokens and the associated + transcription. + Must be called over a given context in the correct order of chunks over + time. + + Arguments + --------- + context : ASRStreamingContext + Mutable streaming context object, which should be the same object + that was passed to `encode_chunk`. + + x : torch.Tensor + The output of `encode_chunk` for a given chunk. + + Returns + ------- + list of str + Decoded tokens of length `batch_size`. The decoded strings can be + of 0-length. + list of list of output token hypotheses + List of length `batch_size`, each holding a list of tokens of any + length `>=0`. + """ + tokens = self.hparams.decoding_function(x, context.decoder_context) + + # initialize token context for real now that we know the batch size + if context.tokenizer_context is None: + context.tokenizer_context = [ + self.hparams.make_tokenizer_streaming_context() + for _ in range(len(tokens)) + ] + + words = [ + self.hparams.tokenizer_decode_streaming( + self.hparams.tokenizer, cur_tokens, context.tokenizer_context[i] + ) + for i, cur_tokens in enumerate(tokens) + ] + + return words, tokens + + def transcribe_chunk( + self, + context: ASRStreamingContext, + chunk: torch.Tensor, + chunk_len: Optional[torch.Tensor] = None, + ): + """Transcription of a batch of audio chunks into transcribed text. + Must be called over a given context in the correct order of chunks over + time. + + Arguments + --------- + context : ASRStreamingContext + Mutable streaming context object, which must be specified and reused + across calls when streaming. + You can obtain an initial context by calling + `asr.make_streaming_context(config)`. + chunk : torch.Tensor + The tensor for an audio chunk of shape `[batch size, time]`. + The time dimension must strictly match + `asr.get_chunk_size_frames(config)`. + The waveform is expected to be in the model's expected format (i.e. + the sampling rate must be correct). + chunk_len : torch.Tensor, optional + The relative chunk length tensor of shape `[batch size]`. This is to + be used when the audio in one of the chunks of the batch is ending + within this chunk. + If unspecified, equivalent to `torch.ones((batch_size,))`. + + Returns + ------- + str + Transcribed string for this chunk, might be of length zero. + """ + + if chunk_len is None: + chunk_len = torch.ones((chunk.size(0),)) + + chunk = chunk.float() + chunk, chunk_len = chunk.to(self.device), chunk_len.to(self.device) + + x = self.encode_chunk(context, chunk, chunk_len) + words, _ = self.decode_chunk(context, x) + + return words + + +class SpeechLLMASR(Pretrained): + """A ready-to-use SpeechLLM ASR model interface. + + The class can be used to run the entire speechllm model. + First, the audio is encoded into a sequence of hidden states using the `speech_encoder`. + Then, the hidden states are downsampled using the `feat_downsampler` and projected using the `proj` module. + The projected features are concatenated with the text embeddings and passed to the `searcher` module. + The `searcher` module returns the predicted tokens and the predicted words using an LLM decoder. + + The given YAML must contains the fields specified in the HPARAMS_NEEDED list. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.ASR import SpeechLLMASR + >>> tmpdir = getfixture("tmpdir") + >>> asr_model = SpeechLLMASR.from_hparams( + ... source="speechbrain/asr-speechllm-librispeech", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> hyp = asr_model.transcribe_file( + ... "speechbrain/asr-speechllm-librispeech/example-en.wav" + ... ) # doctest: +SKIP + >>> hyp # doctest: +SKIP + THE BIRCH CANOE SLID ON THE SMOOTH PLANKS + """ + + HPARAMS_NEEDED = ["bos_index", "eos_index", "prompt"] + MODULES_NEEDED = [ + "speech_encoder", + "feat_downsampler", + "proj", + "llm", + "normalize", + "searcher", + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tokenizer = self.mods.llm.tokenizer + self.txt_embedding = self.mods.llm.model.get_input_embeddings() + + def build_multimodal_embds(self, audio_feats): + """Builds the multimodal embeddings for the audio features.""" + prompt_ids = ( + self.tokenizer( + self.hparams.prompt, + return_tensors="pt", + add_special_tokens=False, + ) + .input_ids.view(-1) + .tolist() + ) + start_of_audio_token = "<|start_of_audio|>" + end_of_audio_token = "<|end_of_audio|>" + start_of_audio_index = self.tokenizer.convert_tokens_to_ids( + start_of_audio_token + ) + end_of_audio_index = self.tokenizer.convert_tokens_to_ids( + end_of_audio_token + ) + prompt_ids = torch.LongTensor( + [start_of_audio_index] + + [end_of_audio_index] + + prompt_ids + + [self.hparams.bos_index] + ).to(audio_feats.device) + prompt_embds = ( + self.txt_embedding(prompt_ids) + .unsqueeze(0) + .repeat(audio_feats.size(0), 1, 1) + ) + multimodal_embds = torch.cat( + [ + prompt_embds[:, 0].unsqueeze(1), # B, D -> B, 1, D + audio_feats, + prompt_embds[:, 1:], + ], + dim=1, + ) + attention_mask = torch.ones( + multimodal_embds.size(0), + multimodal_embds.size(1), + dtype=torch.bool, + device=multimodal_embds.device, + ) + return multimodal_embds, attention_mask + + @torch.no_grad() + def encode_batch(self, wavs, wav_lens): + """Encodes the audio waveforms into a sequence of hidden states. + By default, the `self.inference_ctx` is used to run the forward pass. + Can be overridden by passing a custom `--precision` argument. + + Arguments + --------- + wavs : torch.Tensor + The audio waveforms of shape (batch_size, time). + wav_lens : torch.Tensor + The lengths of the audio waveforms of shape (batch_size,). + + Returns + ------- + audio_feats : torch.Tensor + The encoded audio features of shape (batch_size, time, feat_dim). + """ + with self.inference_ctx: + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + wavs = self.mods.normalize(wavs, wav_lens) + audio_feats = self.mods.speech_encoder(wavs, wav_lens) + return audio_feats + + @torch.no_grad() + def transcribe_batch(self, wavs, wav_lens): + """Transcribes the input audio into a sequence of words. + + Arguments + --------- + wavs : torch.Tensor + The audio waveforms of shape (batch_size, time). + wav_lens : torch.Tensor + The lengths of the audio waveforms of shape (batch_size,). + + Returns + ------- + predicted_words : list + The predicted words of shape (batch_size,). + predicted_tokens : list + The predicted tokens of shape (batch_size,). + """ + with self.inference_ctx: + encoder_out = self.encode_batch(wavs, wav_lens) + audio_down_feats = self.mods.feat_downsampler(encoder_out) + audio_feats = self.mods.proj(audio_down_feats) + multimodal_embds, attention_mask = self.build_multimodal_embds( + audio_feats + ) + # Use the precision configured in self.inference_ctx, defaulting to float32 if not set + target_precision = getattr( + self.inference_ctx, "precision", torch.float32 + ) + hyps = self.mods.searcher( + multimodal_embds.to(target_precision), wav_lens, attention_mask + ) + predicted_tokens = hyps[0] + predicted_words = self.tokenizer.batch_decode( + predicted_tokens, skip_special_tokens=True + ) + return predicted_words, predicted_tokens + + def transcribe_file(self, path, **kwargs): + """Transcribe the given audio file into a sequence of words. + + Arguments + --------- + path : str + The path to the audio file. + **kwargs : dict + Arguments forwarded to `self.load_audio`. + + Returns + ------- + predicted_words : str + The predicted words of the audio file. + """ + waveform = self.load_audio(path, **kwargs) + batch = waveform.unsqueeze(0) + rel_length = torch.tensor([1.0]) + predicted_words, predicted_tokens = self.transcribe_batch( + batch, rel_length + ) + return predicted_words[0] + + def forward(self, wavs, wav_lens): + """Runs full batch decoding""" + return self.transcribe_batch(wavs, wav_lens) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/SLU.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/SLU.py new file mode 100644 index 000000000..e9132609d --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/SLU.py @@ -0,0 +1,144 @@ +"""Specifies the inference interfaces for Spoken Language Understanding (SLU) modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.inference.ASR import EncoderDecoderASR +from speechbrain.inference.interfaces import Pretrained + + +class EndToEndSLU(Pretrained): + """An end-to-end SLU model. + + The class can be used either to run only the encoder (encode()) to extract + features or to run the entire model (decode()) to map the speech to its semantics. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.SLU import EndToEndSLU + >>> tmpdir = getfixture("tmpdir") + >>> slu_model = EndToEndSLU.from_hparams( + ... source="speechbrain/slu-timers-and-such-direct-librispeech-asr", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> slu_model.decode_file( + ... "tests/samples/single-mic/example6.wav" + ... ) # doctest: +SKIP + "{'intent': 'SimpleMath', 'slots': {'number1': 37.67, 'number2': 75.7, 'op': ' minus '}}" + """ + + HPARAMS_NEEDED = ["tokenizer", "asr_model_source"] + MODULES_NEEDED = ["slu_enc", "beam_searcher"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tokenizer = self.hparams.tokenizer + self.asr_model = EncoderDecoderASR.from_hparams( + source=self.hparams.asr_model_source, + run_opts={"device": self.device}, + ) + + def decode_file(self, path, **kwargs): + """Maps the given audio file to a string representing the + semantic dictionary for the utterance. + + Arguments + --------- + path : str + Path to audio file to decode. + **kwargs : dict + Arguments forwarded to ``load_audio``. + + Returns + ------- + str + The predicted semantics. + """ + waveform = self.load_audio(path, **kwargs) + waveform = waveform.to(self.device) + # Fake a batch: + batch = waveform.unsqueeze(0) + rel_length = torch.tensor([1.0]) + predicted_words, predicted_tokens = self.decode_batch(batch, rel_length) + return predicted_words[0] + + def encode_batch(self, wavs, wav_lens): + """Encodes the input audio into a sequence of hidden states + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.Tensor + The encoded batch + """ + wavs = wavs.float() + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + ASR_encoder_out = self.asr_model.encode_batch(wavs.detach(), wav_lens) + encoder_out = self.mods.slu_enc(ASR_encoder_out) + return encoder_out + + def decode_batch(self, wavs, wav_lens): + """Maps the input audio to its semantics + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + list + Each waveform in the batch decoded. + tensor + Each predicted token id. + """ + with torch.no_grad(): + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + encoder_out = self.encode_batch(wavs, wav_lens) + predicted_tokens, scores, _, _ = self.mods.beam_searcher( + encoder_out, wav_lens + ) + predicted_words = [ + self.tokenizer.decode_ids(token_seq) + for token_seq in predicted_tokens + ] + return predicted_words, predicted_tokens + + def forward(self, wavs, wav_lens): + """Runs full decoding - note: no gradients through decoding""" + return self.decode_batch(wavs, wav_lens) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ST.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ST.py new file mode 100644 index 000000000..427a428af --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/ST.py @@ -0,0 +1,138 @@ +"""Specifies the inference interfaces for Speech Translation (ST) modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.inference.interfaces import Pretrained + + +class EncoderDecoderS2UT(Pretrained): + """A ready-to-use Encoder Decoder for speech-to-unit translation model + + The class can be used to run the entire encoder-decoder S2UT model + (translate_file()) to translate speech. The given YAML must contains the fields + specified in the *_NEEDED[] lists. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.ST import EncoderDecoderS2UT + >>> tmpdir = getfixture("tmpdir") + >>> s2ut_model = EncoderDecoderS2UT.from_hparams( + ... source="speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> s2ut_model.translate_file( + ... "speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss/example-fr.wav" + ... ) # doctest: +SKIP + """ + + HPARAMS_NEEDED = ["sample_rate"] + MODULES_NEEDED = ["encoder", "decoder"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.sample_rate = self.hparams.sample_rate + + def translate_file(self, path): + """Translates the given audiofile into a sequence speech unit. + + Arguments + --------- + path : str + Path to audio file which to translate. + + Returns + ------- + int[] + The audiofile translation produced by this speech-to-unit translationmodel. + """ + + audio = self.load_audio(path) + audio = audio.to(self.device) + # Fake a batch: + batch = audio.unsqueeze(0) + rel_length = torch.tensor([1.0]) + predicted_tokens = self.translate_batch(batch, rel_length) + return predicted_tokens[0] + + def encode_batch(self, wavs, wav_lens): + """Encodes the input audio into a sequence of hidden states + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderS2UT.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.tensor + Batch of waveforms [batch, time, channels]. + wav_lens : torch.tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.tensor + The encoded batch + """ + wavs = wavs.float() + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + encoder_out = self.mods.encoder(wavs, wav_lens) + return encoder_out + + def translate_batch(self, wavs, wav_lens): + """Translates the input audio into a sequence of words + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderS2UT.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.tensor + Batch of waveforms [batch, time, channels]. + wav_lens : torch.tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + list + Each waveform in the batch translated. + tensor + Each predicted token id. + """ + with torch.no_grad(): + wav_lens = wav_lens.to(self.device) + encoder_out = self.encode_batch(wavs, wav_lens) + predicted_tokens, _, _, _ = self.mods.decoder(encoder_out, wav_lens) + return predicted_tokens + + def forward(self, wavs, wav_lens): + """Runs full translation""" + return self.encode_batch(wavs, wav_lens) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/TTS.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/TTS.py new file mode 100644 index 000000000..c6c3137ed --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/TTS.py @@ -0,0 +1,928 @@ +"""Specifies the inference interfaces for Text-To-Speech (TTS) modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import random +import re + +import torch +import torchaudio + +import speechbrain +from speechbrain.dataio import audio_io +from speechbrain.inference.classifiers import EncoderClassifier +from speechbrain.inference.encoders import MelSpectrogramEncoder +from speechbrain.inference.interfaces import Pretrained +from speechbrain.inference.text import GraphemeToPhoneme +from speechbrain.utils.fetching import fetch +from speechbrain.utils.logger import get_logger +from speechbrain.utils.text_to_sequence import text_to_sequence + +logger = get_logger(__name__) + + +class Tacotron2(Pretrained): + """ + A ready-to-use wrapper for Tacotron2 (text -> mel_spec). + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> tmpdir_tts = getfixture("tmpdir") / "tts" + >>> tacotron2 = Tacotron2.from_hparams( + ... source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts + ... ) + >>> mel_output, mel_length, alignment = tacotron2.encode_text( + ... "Mary had a little lamb" + ... ) + >>> items = [ + ... "A quick brown fox jumped over the lazy dog", + ... "How much wood would a woodchuck chuck?", + ... "Never odd or even", + ... ] + >>> mel_outputs, mel_lengths, alignments = tacotron2.encode_batch(items) + + >>> # One can combine the TTS model with a vocoder (that generates the final waveform) + >>> # Initialize the Vocoder (HiFIGAN) + >>> tmpdir_vocoder = getfixture("tmpdir") / "vocoder" + >>> from speechbrain.inference.vocoders import HIFIGAN + >>> hifi_gan = HIFIGAN.from_hparams( + ... source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder + ... ) + >>> # Running the TTS + >>> mel_output, mel_length, alignment = tacotron2.encode_text( + ... "Mary had a little lamb" + ... ) + >>> # Running Vocoder (spectrogram-to-waveform) + >>> waveforms = hifi_gan.decode_batch(mel_output) + """ + + HPARAMS_NEEDED = ["model", "text_to_sequence"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.text_cleaners = getattr( + self.hparams, "text_cleaners", ["english_cleaners"] + ) + self.infer = self.hparams.model.infer + + def text_to_seq(self, txt): + """Encodes raw text into a tensor with a customer text-to-sequence function""" + sequence = self.hparams.text_to_sequence(txt, self.text_cleaners) + return sequence, len(sequence) + + def encode_batch(self, texts): + """Computes mel-spectrogram for a list of texts + + Texts must be sorted in decreasing order on their lengths + + Arguments + --------- + texts: List[str] + texts to be encoded into spectrogram + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + with torch.no_grad(): + inputs = [ + { + "text_sequences": torch.tensor( + self.text_to_seq(item)[0], device=self.device + ) + } + for item in texts + ] + inputs = speechbrain.dataio.batch.PaddedBatch(inputs) + + lens = [self.text_to_seq(item)[1] for item in texts] + assert lens == sorted(lens, reverse=True), ( + "input lengths must be sorted in decreasing order" + ) + input_lengths = torch.tensor(lens, device=self.device) + + mel_outputs_postnet, mel_lengths, alignments = self.infer( + inputs.text_sequences.data, input_lengths + ) + return mel_outputs_postnet, mel_lengths, alignments + + def encode_text(self, text): + """Runs inference for a single text str""" + return self.encode_batch([text]) + + def forward(self, texts): + "Encodes the input texts." + return self.encode_batch(texts) + + +class MSTacotron2(Pretrained): + """ + A ready-to-use wrapper for Zero-Shot Multi-Speaker Tacotron2. + For voice cloning: (text, reference_audio) -> (mel_spec). + For generating a random speaker voice: (text) -> (mel_spec). + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> tmpdir_tts = getfixture("tmpdir") / "tts" + >>> mstacotron2 = MSTacotron2.from_hparams( + ... source="speechbrain/tts-mstacotron2-libritts", savedir=tmpdir_tts + ... ) # doctest: +SKIP + >>> # Sample rate of the reference audio must be greater or equal to the sample rate of the speaker embedding model + >>> reference_audio_path = "tests/samples/single-mic/example1.wav" + >>> input_text = "Mary had a little lamb." + >>> mel_output, mel_length, alignment = mstacotron2.clone_voice( + ... input_text, reference_audio_path + ... ) # doctest: +SKIP + >>> # One can combine the TTS model with a vocoder (that generates the final waveform) + >>> # Initialize the Vocoder (HiFIGAN) + >>> tmpdir_vocoder = getfixture("tmpdir") / "vocoder" + >>> from speechbrain.inference.vocoders import HIFIGAN + >>> hifi_gan = HIFIGAN.from_hparams( + ... source="speechbrain/tts-hifigan-libritts-22050Hz", + ... savedir=tmpdir_vocoder, + ... ) # doctest: +SKIP + >>> # Running the TTS + >>> mel_output, mel_length, alignment = mstacotron2.clone_voice( + ... input_text, reference_audio_path + ... ) # doctest: +SKIP + >>> # Running Vocoder (spectrogram-to-waveform) + >>> waveforms = hifi_gan.decode_batch(mel_output) # doctest: +SKIP + >>> # For generating a random speaker voice, use the following + >>> mel_output, mel_length, alignment = mstacotron2.generate_random_voice( + ... input_text + ... ) # doctest: +SKIP + """ + + HPARAMS_NEEDED = ["model"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.text_cleaners = ["english_cleaners"] + self.infer = self.hparams.model.infer + self.custom_mel_spec_encoder = self.hparams.custom_mel_spec_encoder + + self.g2p = GraphemeToPhoneme.from_hparams( + self.hparams.g2p, run_opts={"device": self.device} + ) + + self.spk_emb_encoder = None + if self.custom_mel_spec_encoder: + self.spk_emb_encoder = MelSpectrogramEncoder.from_hparams( + source=self.hparams.spk_emb_encoder, + run_opts={"device": self.device}, + ) + else: + self.spk_emb_encoder = EncoderClassifier.from_hparams( + source=self.hparams.spk_emb_encoder, + run_opts={"device": self.device}, + ) + + def __text_to_seq(self, txt): + """Encodes raw text into a tensor with a customer text-to-sequence function""" + sequence = text_to_sequence(txt, self.text_cleaners) + return sequence, len(sequence) + + def clone_voice(self, texts, audio_path): + """ + Generates mel-spectrogram using input text and reference audio + + Arguments + --------- + texts : str or list + Input text + audio_path : str + Reference audio + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + + # Loads audio + ref_signal, signal_sr = audio_io.load(audio_path) + + # Resamples the audio if required + if signal_sr != self.hparams.spk_emb_sample_rate: + ref_signal = torchaudio.functional.resample( + ref_signal, signal_sr, self.hparams.spk_emb_sample_rate + ) + ref_signal = ref_signal.to(self.device) + + # Computes speaker embedding + if self.custom_mel_spec_encoder: + spk_emb = self.spk_emb_encoder.encode_waveform(ref_signal) + else: + spk_emb = self.spk_emb_encoder.encode_batch(ref_signal) + + spk_emb = spk_emb.squeeze(0) + + # Converts input texts into the corresponding phoneme sequences + if isinstance(texts, str): + texts = [texts] + phoneme_seqs = self.g2p(texts) + for i in range(len(phoneme_seqs)): + phoneme_seqs[i] = " ".join(phoneme_seqs[i]) + phoneme_seqs[i] = "{" + phoneme_seqs[i] + "}" + + # Repeats the speaker embedding to match the number of input texts + spk_embs = spk_emb.repeat(len(texts), 1) + + # Calls __encode_batch to generate the mel-spectrograms + return self.__encode_batch(phoneme_seqs, spk_embs) + + def generate_random_voice(self, texts): + """ + Generates mel-spectrogram using input text and a random speaker voice + + Arguments + --------- + texts : str or list + Input text + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + + spk_emb = self.__sample_random_speaker().float() + spk_emb = spk_emb.to(self.device) + + # Converts input texts into the corresponding phoneme sequences + if isinstance(texts, str): + texts = [texts] + phoneme_seqs = self.g2p(texts) + for i in range(len(phoneme_seqs)): + phoneme_seqs[i] = " ".join(phoneme_seqs[i]) + phoneme_seqs[i] = "{" + phoneme_seqs[i] + "}" + + # Repeats the speaker embedding to match the number of input texts + spk_embs = spk_emb.repeat(len(texts), 1) + + # Calls __encode_batch to generate the mel-spectrograms + return self.__encode_batch(phoneme_seqs, spk_embs) + + def __encode_batch(self, texts, spk_embs): + """Computes mel-spectrograms for a list of texts + Texts are sorted in decreasing order on their lengths + + Arguments + --------- + texts: List[str] + texts to be encoded into spectrogram + spk_embs: torch.Tensor + speaker embeddings + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + + with torch.no_grad(): + inputs = [ + { + "text_sequences": torch.tensor( + self.__text_to_seq(item)[0], device=self.device + ) + } + for item in texts + ] + + inputs = sorted( + inputs, + key=lambda x: x["text_sequences"].size()[0], + reverse=True, + ) + + lens = [entry["text_sequences"].size()[0] for entry in inputs] + + inputs = speechbrain.dataio.batch.PaddedBatch(inputs) + + assert lens == sorted(lens, reverse=True), ( + "input lengths must be sorted in decreasing order" + ) + input_lengths = torch.tensor(lens, device=self.device) + + mel_outputs_postnet, mel_lengths, alignments = self.infer( + inputs.text_sequences.data, spk_embs, input_lengths + ) + return mel_outputs_postnet, mel_lengths, alignments + + def __sample_random_speaker(self): + """Samples a random speaker embedding from a pretrained GMM + + Returns + ------- + x: torch.Tensor + A randomly sampled speaker embedding + """ + + # Fetches and Loads GMM trained on speaker embeddings + speaker_gmm_local_path = fetch( + filename=self.hparams.random_speaker_sampler, + source=self.hparams.random_speaker_sampler_source, + savedir=self.hparams.pretrainer.collect_in, + ) + random_speaker_gmm = torch.load(speaker_gmm_local_path) + gmm_n_components = random_speaker_gmm["gmm_n_components"] + gmm_means = random_speaker_gmm["gmm_means"] + gmm_covariances = random_speaker_gmm["gmm_covariances"] + + # Randomly selects a speaker + counts = torch.zeros(gmm_n_components) + counts[random.randint(0, gmm_n_components - 1)] = 1 + x = torch.empty(0, device=counts.device) + + # Samples an embedding for the speaker + for k in torch.arange(gmm_n_components)[counts > 0]: + # Considers full covariance type + d_k = torch.distributions.multivariate_normal.MultivariateNormal( + gmm_means[k], gmm_covariances[k] + ) + x_k = torch.stack([d_k.sample() for _ in range(int(counts[k]))]) + + x = torch.cat((x, x_k), dim=0) + + return x + + +class FastSpeech2(Pretrained): + """ + A ready-to-use wrapper for Fastspeech2 (text -> mel_spec). + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> tmpdir_tts = getfixture("tmpdir") / "tts" + >>> fastspeech2 = FastSpeech2.from_hparams( + ... source="speechbrain/tts-fastspeech2-ljspeech", savedir=tmpdir_tts + ... ) # doctest: +SKIP + >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text( + ... ["Mary had a little lamb."] + ... ) # doctest: +SKIP + >>> items = [ + ... "A quick brown fox jumped over the lazy dog", + ... "How much wood would a woodchuck chuck?", + ... "Never odd or even", + ... ] + >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text( + ... items + ... ) # doctest: +SKIP + >>> + >>> # One can combine the TTS model with a vocoder (that generates the final waveform) + >>> # Initialize the Vocoder (HiFIGAN) + >>> tmpdir_vocoder = getfixture("tmpdir") / "vocoder" + >>> from speechbrain.inference.vocoders import HIFIGAN + >>> hifi_gan = HIFIGAN.from_hparams( + ... source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder + ... ) # doctest: +SKIP + >>> # Running the TTS + >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text( + ... ["Mary had a little lamb."] + ... ) # doctest: +SKIP + >>> # Running Vocoder (spectrogram-to-waveform) + >>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP + """ + + HPARAMS_NEEDED = ["spn_predictor", "model", "input_encoder"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + lexicon = self.hparams.lexicon + lexicon = ["@@"] + lexicon + self.input_encoder = self.hparams.input_encoder + self.input_encoder.update_from_iterable(lexicon, sequence_input=False) + self.input_encoder.add_unk() + + self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p") + + self.spn_token_encoded = ( + self.input_encoder.encode_sequence_torch(["spn"]).int().item() + ) + + def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0): + """Computes mel-spectrogram for a list of texts + + Arguments + --------- + texts: List[str] + texts to be converted to spectrogram + pace: float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + + # Preprocessing required at the inference time for the input text + # "label" below contains input text + # "phoneme_labels" contain the phoneme sequences corresponding to input text labels + # "last_phonemes_combined" is used to indicate whether the index position is for a last phoneme of a word + # "punc_positions" is used to add back the silence for punctuations + phoneme_labels = list() + last_phonemes_combined = list() + punc_positions = list() + + for label in texts: + phoneme_label = list() + last_phonemes = list() + punc_position = list() + + words = label.split() + words = [word.strip() for word in words] + words_phonemes = self.g2p(words) + + for i in range(len(words_phonemes)): + words_phonemes_seq = words_phonemes[i] + for phoneme in words_phonemes_seq: + if not phoneme.isspace(): + phoneme_label.append(phoneme) + last_phonemes.append(0) + punc_position.append(0) + last_phonemes[-1] = 1 + if words[i][-1] in ":;-,.!?": + punc_position[-1] = 1 + + phoneme_labels.append(phoneme_label) + last_phonemes_combined.append(last_phonemes) + punc_positions.append(punc_position) + + # Inserts silent phonemes in the input phoneme sequence + all_tokens_with_spn = list() + max_seq_len = -1 + for i in range(len(phoneme_labels)): + phoneme_label = phoneme_labels[i] + token_seq = ( + self.input_encoder.encode_sequence_torch(phoneme_label) + .int() + .to(self.device) + ) + last_phonemes = torch.LongTensor(last_phonemes_combined[i]).to( + self.device + ) + + # Runs the silent phoneme predictor + spn_preds = ( + self.hparams.modules["spn_predictor"] + .infer(token_seq.unsqueeze(0), last_phonemes.unsqueeze(0)) + .int() + ) + + spn_to_add = torch.nonzero(spn_preds).reshape(-1).tolist() + + for j in range(len(punc_positions[i])): + if punc_positions[i][j] == 1: + spn_to_add.append(j) + + tokens_with_spn = list() + + for token_idx in range(token_seq.shape[0]): + tokens_with_spn.append(token_seq[token_idx].item()) + if token_idx in spn_to_add: + tokens_with_spn.append(self.spn_token_encoded) + + tokens_with_spn = torch.LongTensor(tokens_with_spn).to(self.device) + all_tokens_with_spn.append(tokens_with_spn) + if max_seq_len < tokens_with_spn.shape[-1]: + max_seq_len = tokens_with_spn.shape[-1] + + # "tokens_with_spn_tensor" holds the input phoneme sequence with silent phonemes + tokens_with_spn_tensor_padded = torch.LongTensor( + len(texts), max_seq_len + ).to(self.device) + tokens_with_spn_tensor_padded.zero_() + + for seq_idx, seq in enumerate(all_tokens_with_spn): + tokens_with_spn_tensor_padded[seq_idx, : len(seq)] = seq + + return self.encode_batch( + tokens_with_spn_tensor_padded, + pace=pace, + pitch_rate=pitch_rate, + energy_rate=energy_rate, + ) + + def encode_phoneme( + self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0 + ): + """Computes mel-spectrogram for a list of phoneme sequences + + Arguments + --------- + phonemes: List[List[str]] + phonemes to be converted to spectrogram + pace: float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + + all_tokens = [] + max_seq_len = -1 + for phoneme in phonemes: + token_seq = ( + self.input_encoder.encode_sequence_torch(phoneme) + .int() + .to(self.device) + ) + if max_seq_len < token_seq.shape[-1]: + max_seq_len = token_seq.shape[-1] + all_tokens.append(token_seq) + + tokens_padded = torch.LongTensor(len(phonemes), max_seq_len).to( + self.device + ) + tokens_padded.zero_() + + for seq_idx, seq in enumerate(all_tokens): + tokens_padded[seq_idx, : len(seq)] = seq + + return self.encode_batch( + tokens_padded, + pace=pace, + pitch_rate=pitch_rate, + energy_rate=energy_rate, + ) + + def encode_batch( + self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0 + ): + """Batch inference for a tensor of phoneme sequences + + Arguments + --------- + tokens_padded : torch.Tensor + A sequence of encoded phonemes to be converted to spectrogram + pace : float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + post_mel_outputs : torch.Tensor + durations : torch.Tensor + pitch : torch.Tensor + energy : torch.Tensor + """ + with torch.no_grad(): + ( + _, + post_mel_outputs, + durations, + pitch, + _, + energy, + _, + _, + ) = self.hparams.model( + tokens_padded, + pace=pace, + pitch_rate=pitch_rate, + energy_rate=energy_rate, + ) + + # Transposes to make in compliant with HiFI GAN expected format + post_mel_outputs = post_mel_outputs.transpose(-1, 1) + + return post_mel_outputs, durations, pitch, energy + + def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0): + """Batch inference for a tensor of phoneme sequences + + Arguments + --------- + text : str + A text to be converted to spectrogram + pace : float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + Encoded text + """ + return self.encode_text( + [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate + ) + + +class FastSpeech2InternalAlignment(Pretrained): + """ + A ready-to-use wrapper for Fastspeech2 with internal alignment(text -> mel_spec). + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> tmpdir_tts = getfixture("tmpdir") / "tts" + >>> fastspeech2 = FastSpeech2InternalAlignment.from_hparams( + ... source="speechbrain/tts-fastspeech2-internal-alignment-ljspeech", + ... savedir=tmpdir_tts, + ... ) # doctest: +SKIP + >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text( + ... ["Mary had a little lamb."] + ... ) # doctest: +SKIP + >>> items = [ + ... "A quick brown fox jumped over the lazy dog", + ... "How much wood would a woodchuck chuck?", + ... "Never odd or even", + ... ] + >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text( + ... items + ... ) # doctest: +SKIP + >>> # One can combine the TTS model with a vocoder (that generates the final waveform) + >>> # Initialize the Vocoder (HiFIGAN) + >>> tmpdir_vocoder = getfixture("tmpdir") / "vocoder" + >>> from speechbrain.inference.vocoders import HIFIGAN + >>> hifi_gan = HIFIGAN.from_hparams( + ... source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder + ... ) # doctest: +SKIP + >>> # Running the TTS + >>> mel_outputs, durations, pitch, energy = fastspeech2.encode_text( + ... ["Mary had a little lamb."] + ... ) # doctest: +SKIP + >>> # Running Vocoder (spectrogram-to-waveform) + >>> waveforms = hifi_gan.decode_batch(mel_outputs) # doctest: +SKIP + """ + + HPARAMS_NEEDED = ["model", "input_encoder"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + lexicon = self.hparams.lexicon + lexicon = ["@@"] + lexicon + self.input_encoder = self.hparams.input_encoder + self.input_encoder.update_from_iterable(lexicon, sequence_input=False) + self.input_encoder.add_unk() + + self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p") + + def encode_text(self, texts, pace=1.0, pitch_rate=1.0, energy_rate=1.0): + """Computes mel-spectrogram for a list of texts + + Arguments + --------- + texts: List[str] + texts to be converted to spectrogram + pace: float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + + # Preprocessing required at the inference time for the input text + # "label" below contains input text + # "phoneme_labels" contain the phoneme sequences corresponding to input text labels + + phoneme_labels = list() + max_seq_len = -1 + + for label in texts: + phonemes_with_punc = self._g2p_keep_punctuations(self.g2p, label) + if max_seq_len < len(phonemes_with_punc): + max_seq_len = len(phonemes_with_punc) + token_seq = ( + self.input_encoder.encode_sequence_torch(phonemes_with_punc) + .int() + .to(self.device) + ) + phoneme_labels.append(token_seq) + + tokens_padded = torch.LongTensor(len(texts), max_seq_len).to( + self.device + ) + tokens_padded.zero_() + + for seq_idx, seq in enumerate(phoneme_labels): + tokens_padded[seq_idx, : len(seq)] = seq + + return self.encode_batch( + tokens_padded, + pace=pace, + pitch_rate=pitch_rate, + energy_rate=energy_rate, + ) + + def _g2p_keep_punctuations(self, g2p_model, text): + """do grapheme to phoneme and keep the punctuations between the words""" + # find the words where a "-" or "'" or "." or ":" appears in the middle + special_words = re.findall(r"\w+[-':\.][-':\.\w]*\w+", text) + + # remove intra-word punctuations ("-':."), this does not change the output of speechbrain g2p + for special_word in special_words: + rmp = special_word.replace("-", "") + rmp = rmp.replace("'", "") + rmp = rmp.replace(":", "") + rmp = rmp.replace(".", "") + text = text.replace(special_word, rmp) + + # keep inter-word punctuations + all_ = re.findall(r"[\w]+|[-!'(),.:;? ]", text) + try: + phonemes = g2p_model(text) + except RuntimeError: + logger.info(f"error with text: {text}") + quit() + word_phonemes = "-".join(phonemes).split(" ") + + phonemes_with_punc = [] + count = 0 + try: + # if the g2p model splits the words correctly + for i in all_: + if i not in "-!'(),.:;? ": + phonemes_with_punc.extend(word_phonemes[count].split("-")) + count += 1 + else: + phonemes_with_punc.append(i) + except IndexError: + # sometimes the g2p model cannot split the words correctly + logger.warning( + f"Do g2p word by word because of unexpected outputs from g2p for text: {text}" + ) + + for i in all_: + if i not in "-!'(),.:;? ": + p = g2p_model.g2p(i) + p_without_space = [i for i in p if i != " "] + phonemes_with_punc.extend(p_without_space) + else: + phonemes_with_punc.append(i) + + while "" in phonemes_with_punc: + phonemes_with_punc.remove("") + return phonemes_with_punc + + def encode_phoneme( + self, phonemes, pace=1.0, pitch_rate=1.0, energy_rate=1.0 + ): + """Computes mel-spectrogram for a list of phoneme sequences + + Arguments + --------- + phonemes: List[List[str]] + phonemes to be converted to spectrogram + pace: float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + tensors of output spectrograms, output lengths and alignments + """ + + all_tokens = [] + max_seq_len = -1 + for phoneme in phonemes: + token_seq = ( + self.input_encoder.encode_sequence_torch(phoneme) + .int() + .to(self.device) + ) + if max_seq_len < token_seq.shape[-1]: + max_seq_len = token_seq.shape[-1] + all_tokens.append(token_seq) + + tokens_padded = torch.LongTensor(len(phonemes), max_seq_len).to( + self.device + ) + tokens_padded.zero_() + + for seq_idx, seq in enumerate(all_tokens): + tokens_padded[seq_idx, : len(seq)] = seq + + return self.encode_batch( + tokens_padded, + pace=pace, + pitch_rate=pitch_rate, + energy_rate=energy_rate, + ) + + def encode_batch( + self, tokens_padded, pace=1.0, pitch_rate=1.0, energy_rate=1.0 + ): + """Batch inference for a tensor of phoneme sequences + + Arguments + --------- + tokens_padded : torch.Tensor + A sequence of encoded phonemes to be converted to spectrogram + pace : float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + post_mel_outputs : torch.Tensor + durations : torch.Tensor + pitch : torch.Tensor + energy : torch.Tensor + """ + with torch.no_grad(): + ( + _, + post_mel_outputs, + durations, + pitch, + _, + energy, + _, + _, + _, + _, + _, + _, + ) = self.hparams.model( + tokens_padded, + pace=pace, + pitch_rate=pitch_rate, + energy_rate=energy_rate, + ) + + # Transposes to make in compliant with HiFI GAN expected format + post_mel_outputs = post_mel_outputs.transpose(-1, 1) + + return post_mel_outputs, durations, pitch, energy + + def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0): + """Batch inference for a tensor of phoneme sequences + + Arguments + --------- + text : str + A text to be converted to spectrogram + pace : float + pace for the speech synthesis + pitch_rate : float + scaling factor for phoneme pitches + energy_rate : float + scaling factor for phoneme energies + + Returns + ------- + Encoded text + """ + return self.encode_text( + [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate + ) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/VAD.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/VAD.py new file mode 100644 index 000000000..968647ab3 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/VAD.py @@ -0,0 +1,965 @@ +"""Specifies the inference interfaces for Voice Activity Detection (VAD) modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.dataio import audio_io +from speechbrain.inference.interfaces import Pretrained +from speechbrain.utils.data_utils import split_path +from speechbrain.utils.fetching import fetch + + +class VAD(Pretrained): + """A ready-to-use class for Voice Activity Detection (VAD) using a + pre-trained model. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> import torchaudio + >>> from speechbrain.inference.VAD import VAD + >>> # Model is downloaded from the speechbrain HuggingFace repo + >>> tmpdir = getfixture("tmpdir") + >>> VAD = VAD.from_hparams( + ... source="speechbrain/vad-crdnn-libriparty", + ... savedir=tmpdir, + ... ) + + >>> # Perform VAD + >>> boundaries = VAD.get_speech_segments( + ... "tests/samples/single-mic/example1.wav" + ... ) + """ + + HPARAMS_NEEDED = ["sample_rate", "time_resolution", "device"] + + MODULES_NEEDED = ["compute_features", "mean_var_norm", "model"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.time_resolution = self.hparams.time_resolution + self.sample_rate = self.hparams.sample_rate + + def get_speech_prob_file( + self, + audio_file, + large_chunk_size=30, + small_chunk_size=10, + overlap_small_chunk=False, + ): + """Outputs the frame-level speech probability of the input audio file + using the neural model specified in the hparam file. To make this code + both parallelizable and scalable to long sequences, it uses a + double-windowing approach. First, we sequentially read non-overlapping + large chunks of the input signal. We then split the large chunks into + smaller chunks and we process them in parallel. + + Arguments + --------- + audio_file: path + Path of the audio file containing the recording. The file is read + with torchaudio. + large_chunk_size: float + Size (in seconds) of the large chunks that are read sequentially + from the input audio file. + small_chunk_size: float + Size (in seconds) of the small chunks extracted from the large ones. + The audio signal is processed in parallel within the small chunks. + Note that large_chunk_size/small_chunk_size must be an integer. + overlap_small_chunk: bool + True, creates overlapped small chunks. The probabilities of the + overlapped chunks are combined using hamming windows. + + Returns + ------- + prob_vad: torch.Tensor + torch.Tensor containing the frame-level speech probabilities for the + input audio file. + """ + # Getting the total size of the input file + sample_rate, audio_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + + # Computing the length (in samples) of the large and small chunks + long_chunk_len = int(sample_rate * large_chunk_size) + small_chunk_len = int(sample_rate * small_chunk_size) + + # Setting the step size of the small chunk (50% overlapping windows are supported) + small_chunk_step = small_chunk_size + if overlap_small_chunk: + small_chunk_step = small_chunk_size / 2 + + # Computing the length (in sample) of the small_chunk step size + small_chunk_len_step = int(sample_rate * small_chunk_step) + + # Loop over big chunks + prob_chunks = [] + last_chunk = False + begin_sample = 0 + while True: + # Check if the current chunk is the last one + if begin_sample + long_chunk_len >= audio_len: + last_chunk = True + + # Reading the big chunk + large_chunk, fs = audio_io.load( + str(audio_file), + frame_offset=begin_sample, + num_frames=long_chunk_len, + ) + large_chunk = large_chunk.to(self.device) + + # Manage padding of the last small chunk + if last_chunk or large_chunk.shape[-1] < small_chunk_len: + padding = torch.zeros( + 1, small_chunk_len, device=large_chunk.device + ) + large_chunk = torch.cat([large_chunk, padding], dim=1) + + # Splitting the big chunk into smaller (overlapped) ones + small_chunks = torch.nn.functional.unfold( + large_chunk.unsqueeze(1).unsqueeze(2), + kernel_size=(1, small_chunk_len), + stride=(1, small_chunk_len_step), + ) + small_chunks = small_chunks.squeeze(0).transpose(0, 1) + + # Getting (in parallel) the frame-level speech probabilities + small_chunks_prob = self.get_speech_prob_chunk(small_chunks) + small_chunks_prob = small_chunks_prob[:, :-1, :] + + # Manage overlapping chunks + if overlap_small_chunk: + small_chunks_prob = self._manage_overlapped_chunks( + small_chunks_prob + ) + + # Prepare for folding + small_chunks_prob = small_chunks_prob.permute(2, 1, 0) + + # Computing lengths in samples + out_len = int( + large_chunk.shape[-1] / (sample_rate * self.time_resolution) + ) + kernel_len = int(small_chunk_size / self.time_resolution) + step_len = int(small_chunk_step / self.time_resolution) + + # Folding the frame-level predictions + small_chunks_prob = torch.nn.functional.fold( + small_chunks_prob, + output_size=(1, out_len), + kernel_size=(1, kernel_len), + stride=(1, step_len), + ) + + # Appending the frame-level speech probabilities of the large chunk + small_chunks_prob = small_chunks_prob.squeeze(1).transpose(-1, -2) + prob_chunks.append(small_chunks_prob) + + # Check stop condition + if last_chunk: + break + + # Update counter to process the next big chunk + begin_sample = begin_sample + long_chunk_len + + # Converting the list to a tensor + prob_vad = torch.cat(prob_chunks, dim=1) + last_elem = int(audio_len / (self.time_resolution * sample_rate)) + prob_vad = prob_vad[:, 0:last_elem, :] + + return prob_vad + + def _manage_overlapped_chunks(self, small_chunks_prob): + """This support function manages overlapped the case in which the + small chunks have a 50% overlap.""" + + # Weighting the frame-level probabilities with a hamming window + # reduces uncertainty when overlapping chunks are used. + hamming_window = torch.hamming_window( + small_chunks_prob.shape[1], device=self.device + ) + + # First and last chunks require special care + half_point = int(small_chunks_prob.shape[1] / 2) + small_chunks_prob[0, half_point:] = small_chunks_prob[ + 0, half_point: + ] * hamming_window[half_point:].unsqueeze(1) + small_chunks_prob[-1, 0:half_point] = small_chunks_prob[ + -1, 0:half_point + ] * hamming_window[0:half_point].unsqueeze(1) + + # Applying the window to all the other probabilities + small_chunks_prob[1:-1] = small_chunks_prob[ + 1:-1 + ] * hamming_window.unsqueeze(0).unsqueeze(2) + + return small_chunks_prob + + def get_speech_prob_chunk(self, wavs, wav_lens=None): + """Outputs the frame-level posterior probability for the input audio chunks + Outputs close to zero refers to time steps with a low probability of speech + activity, while outputs closer to one likely contain speech. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. Make sure the sample rate is fs=16000 Hz. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.Tensor + The encoded batch + """ + # Manage single waveforms in input + if len(wavs.shape) == 1: + wavs = wavs.unsqueeze(0) + + # Assign full length if wav_lens is not assigned + if wav_lens is None: + wav_lens = torch.ones(wavs.shape[0], device=self.device) + + # Storing waveform in the specified device + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + wavs = wavs.float() + + # Computing features and embeddings + feats = self.mods.compute_features(wavs) + feats = self.mods.mean_var_norm(feats, wav_lens) + outputs = self.mods.cnn(feats) + + outputs = outputs.reshape( + outputs.shape[0], + outputs.shape[1], + outputs.shape[2] * outputs.shape[3], + ) + + outputs, h = self.mods.rnn(outputs) + outputs = self.mods.dnn(outputs) + output_prob = torch.sigmoid(outputs) + + return output_prob + + def apply_threshold( + self, vad_prob, activation_th=0.5, deactivation_th=0.25 + ): + """Scans the frame-level speech probabilities and applies a threshold + on them. Speech starts when a value larger than activation_th is + detected, while it ends when observing a value lower than + the deactivation_th. + + Arguments + --------- + vad_prob: torch.Tensor + Frame-level speech probabilities. + activation_th: float + Threshold for starting a speech segment. + deactivation_th: float + Threshold for ending a speech segment. + + Returns + ------- + vad_th: torch.BoolTensor + torch.Tensor containing 1 for speech regions and 0 for non-speech regions. + """ + # whether the n-th frame falls below threshold and triggers deactivation + frame_does_not_deactivate = (vad_prob >= deactivation_th).to("cpu") + + # always start keeping frames over activation threshold activated + vad_th = (vad_prob >= activation_th).to("cpu") + + for i in range(1, vad_prob.shape[1]): + # if the previous frame was activated, then keep it activated... + vad_th[:, i, ...] |= vad_th[:, i - 1, ...] + + # ... unless the i-th (current) frame is below threshold + vad_th[:, i, ...] &= frame_does_not_deactivate[:, i, ...] + + return vad_th.to(vad_prob.device) + + def get_boundaries(self, prob_th, output_value="seconds"): + """Computes the time boundaries where speech activity is detected. + It takes in input frame-level binary decisions + (1 for speech, 0 for non-speech) and outputs the begin/end second + (or sample) of each detected speech region. + + Arguments + --------- + prob_th: torch.Tensor + Frame-level binary decisions (1 for speech frame, 0 for a + non-speech one). The tensor can be obtained from apply_threshold. + output_value: 'seconds' or 'samples' + When the option 'seconds' is set, the returned boundaries are in + seconds, otherwise, it reports them in samples. + + Returns + ------- + boundaries: torch.Tensor + torch.Tensor containing the start second (or sample) of speech segments + in even positions and their corresponding end in odd positions + (e.g, [1.0, 1.5, 5,.0 6.0] means that we have two speech segment; + one from 1.0 to 1.5 seconds and another from 5.0 to 6.0 seconds). + """ + # Shifting frame-levels binary decision by 1 + # This allows detecting changes in speech/non-speech activities + prob_th_shifted = torch.roll(prob_th, dims=1, shifts=1) + prob_th_shifted[:, 0, :] = 0 + prob_th = prob_th + prob_th_shifted + + # Needed to first and last time step + prob_th[:, 0, :] = (prob_th[:, 0, :] >= 1).int() + prob_th[:, -1, :] = (prob_th[:, -1, :] >= 1).int() + + # Fix edge cases (when a speech starts in the last frames) + if (prob_th == 1).nonzero().shape[0] % 2 == 1: + prob_th = torch.cat( + ( + prob_th, + torch.Tensor([1.0]) + .unsqueeze(0) + .unsqueeze(2) + .to(self.device), + ), + dim=1, + ) + + # Where prob_th is 1 there is a change + indexes = (prob_th == 1).nonzero()[:, 1].reshape(-1, 2) + + # Remove 1 from end samples + indexes[:, -1] = indexes[:, -1] - 1 + + # From indexes to samples + seconds = (indexes * self.time_resolution).float() + samples = (self.sample_rate * seconds).round().int() + + if output_value == "seconds": + boundaries = seconds + else: + boundaries = samples + return boundaries + + def merge_close_segments(self, boundaries, close_th=0.250): + """Merges segments that are shorter than the given threshold. + + Arguments + --------- + boundaries : str + torch.Tensor containing the speech boundaries. It can be derived using the + get_boundaries method. + close_th: float + If the distance between boundaries is smaller than close_th, the + segments will be merged. + + Returns + ------- + new_boundaries + The new boundaries with the merged segments. + """ + + new_boundaries = [] + + # Single segment case + if boundaries.shape[0] == 0: + return boundaries + + # Getting beg and end of previous segment + prev_beg_seg = boundaries[0, 0].float() + prev_end_seg = boundaries[0, 1].float() + + # Process all the segments + for i in range(1, boundaries.shape[0]): + beg_seg = boundaries[i, 0] + segment_distance = beg_seg - prev_end_seg + + # Merging close segments + if segment_distance <= close_th: + prev_end_seg = boundaries[i, 1] + + else: + # Appending new segments + new_boundaries.append([prev_beg_seg, prev_end_seg]) + prev_beg_seg = beg_seg + prev_end_seg = boundaries[i, 1] + + new_boundaries.append([prev_beg_seg, prev_end_seg]) + new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) + return new_boundaries + + def remove_short_segments(self, boundaries, len_th=0.250): + """Removes segments that are too short. + + Arguments + --------- + boundaries : torch.Tensor + torch.Tensor containing the speech boundaries. It can be derived using the + get_boundaries method. + len_th: float + If the length of the segment is smaller than close_th, the segments + will be merged. + + Returns + ------- + new_boundaries + The new boundaries without the short segments. + """ + new_boundaries = [] + + # Process the segments + for i in range(boundaries.shape[0]): + # Computing segment length + seg_len = boundaries[i, 1] - boundaries[i, 0] + + # Accept segment only if longer than len_th + if seg_len > len_th: + new_boundaries.append([boundaries[i, 0], boundaries[i, 1]]) + new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) + + return new_boundaries + + def save_boundaries( + self, boundaries, save_path=None, print_boundaries=True, audio_file=None + ): + """Saves the boundaries on a file (and/or prints them) in a readable format. + + Arguments + --------- + boundaries: torch.Tensor + torch.Tensor containing the speech boundaries. It can be derived using the + get_boundaries method. + save_path: path + When to store the text file containing the speech/non-speech intervals. + print_boundaries: Bool + Prints the speech/non-speech intervals in the standard outputs. + audio_file: path + Path of the audio file containing the recording. The file is read + with torchaudio. It is used here to detect the length of the + signal. + """ + # Create a new file if needed + if save_path is not None: + f = open(save_path, mode="w", encoding="utf-8") + + # Getting the total size of the input file + if audio_file is not None: + sample_rate, audio_len = self._get_audio_info(audio_file) + audio_len = audio_len / sample_rate + + # Setting the rights format for second- or sample-based boundaries + if boundaries.dtype == torch.int: + value_format = "% i" + else: + value_format = "% .2f " + + # Printing speech and non-speech intervals + last_end = 0 + cnt_seg = 0 + for i in range(boundaries.shape[0]): + begin_value = boundaries[i, 0] + end_value = boundaries[i, 1] + + if last_end != begin_value: + cnt_seg = cnt_seg + 1 + print_str = ( + "segment_%03d " + value_format + value_format + "NON_SPEECH" + ) + if print_boundaries: + print(print_str % (cnt_seg, last_end, begin_value)) + if save_path is not None: + f.write(print_str % (cnt_seg, last_end, begin_value) + "\n") + + cnt_seg = cnt_seg + 1 + print_str = "segment_%03d " + value_format + value_format + "SPEECH" + if print_boundaries: + print(print_str % (cnt_seg, begin_value, end_value)) + if save_path is not None: + f.write(print_str % (cnt_seg, begin_value, end_value) + "\n") + + last_end = end_value + + # Managing last segment + if audio_file is not None: + if last_end < audio_len: + cnt_seg = cnt_seg + 1 + print_str = ( + "segment_%03d " + value_format + value_format + "NON_SPEECH" + ) + if print_boundaries: + print(print_str % (cnt_seg, end_value, audio_len)) + if save_path is not None: + f.write(print_str % (cnt_seg, end_value, audio_len) + "\n") + + if save_path is not None: + f.close() + + def energy_VAD( + self, + audio_file, + boundaries, + activation_th=0.5, + deactivation_th=0.0, + eps=1e-6, + ): + """Applies energy-based VAD within the detected speech segments.The neural + network VAD often creates longer segments and tends to merge segments that + are close with each other. + + The energy VAD post-processes can be useful for having a fine-grained voice + activity detection. + + The energy VAD computes the energy within the small chunks. The energy is + normalized within the segment to have mean 0.5 and +-0.5 of std. + This helps to set the energy threshold. + + Arguments + --------- + audio_file: path + Path of the audio file containing the recording. The file is read + with torchaudio. + boundaries: torch.Tensor + torch.Tensor containing the speech boundaries. It can be derived using the + get_boundaries method. + activation_th: float + A new speech segment is started it the energy is above activation_th. + deactivation_th: float + The segment is considered ended when the energy is <= deactivation_th. + eps: float + Small constant for numerical stability. + + Returns + ------- + new_boundaries + The new boundaries that are post-processed by the energy VAD. + """ + + # Getting the total size of the input file + sample_rate, audio_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + + # Computing the chunk length of the energy window + chunk_len = int(self.time_resolution * sample_rate) + new_boundaries = [] + + # Processing speech segments + for i in range(boundaries.shape[0]): + begin_sample = int(boundaries[i, 0] * sample_rate) + end_sample = int(boundaries[i, 1] * sample_rate) + seg_len = end_sample - begin_sample + + # Reading the speech segment + segment, _ = audio_io.load( + audio_file, frame_offset=begin_sample, num_frames=seg_len + ) + segment = segment.to(self.device) + # Create chunks + segment_chunks = self.create_chunks( + segment, chunk_size=chunk_len, chunk_stride=chunk_len + ) + + # Energy computation within each chunk + energy_chunks = segment_chunks.abs().sum(-1) + eps + energy_chunks = energy_chunks.log() + + # Energy normalization + energy_chunks = ( + (energy_chunks - energy_chunks.mean()) + / (2 * energy_chunks.std()) + ) + 0.5 + energy_chunks = energy_chunks.unsqueeze(0).unsqueeze(2) + + # Apply threshold based on the energy value + energy_vad = self.apply_threshold( + energy_chunks, + activation_th=activation_th, + deactivation_th=deactivation_th, + ) + + # Get the boundaries + energy_boundaries = self.get_boundaries( + energy_vad, output_value="seconds" + ) + + # Get the final boundaries in the original signal + for j in range(energy_boundaries.shape[0]): + start_en = boundaries[i, 0] + energy_boundaries[j, 0] + end_end = boundaries[i, 0] + energy_boundaries[j, 1] + new_boundaries.append([start_en, end_end]) + + # Convert boundaries to tensor + new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) + return new_boundaries + + def create_chunks(self, x, chunk_size=16384, chunk_stride=16384): + """Splits the input into smaller chunks of size chunk_size with + an overlap chunk_stride. The chunks are concatenated over + the batch axis. + + Arguments + --------- + x: torch.Tensor + Signal to split into chunks. + chunk_size : int + The size of each chunk. + chunk_stride: int + The stride (hop) of each chunk. + + Returns + ------- + x: torch.Tensor + A new tensors with the chunks derived from the input signal. + """ + x = x.unfold(1, chunk_size, chunk_stride) + x = x.reshape(x.shape[0] * x.shape[1], -1) + return x + + def _get_audio_info(self, audio_file): + """Returns the sample rate and the length of the input audio file""" + + # Getting the total size of the input file + metadata = audio_io.info(str(audio_file)) + sample_rate = metadata.sample_rate + audio_len = metadata.num_frames + return sample_rate, audio_len + + def upsample_VAD(self, vad_out, audio_file, time_resolution=0.01): + """Upsamples the output of the vad to help visualization. It creates a + signal that is 1 when there is speech and 0 when there is no speech. + The vad signal has the same resolution as the input one and can be + opened with it (e.g, using audacity) to visually figure out VAD regions. + + Arguments + --------- + vad_out: torch.Tensor + torch.Tensor containing 1 for each frame of speech and 0 for each non-speech + frame. + audio_file: path + The original audio file used to compute vad_out + time_resolution : float + Time resolution of the vad_out signal. + + Returns + ------- + vad_signal + The upsampled version of the vad_out tensor. + """ + + # Getting the total size of the input file + sample_rate, sig_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + + beg_samp = 0 + step_size = int(time_resolution * sample_rate) + end_samp = step_size + index = 0 + + # Initialize upsampled signal + vad_signal = torch.zeros(1, sig_len, device=vad_out.device) + + # Upsample signal + while end_samp < sig_len: + vad_signal[0, beg_samp:end_samp] = vad_out[0, index, 0] + index = index + 1 + beg_samp = beg_samp + step_size + end_samp = beg_samp + step_size + return vad_signal + + def upsample_boundaries(self, boundaries, audio_file): + """Based on the input boundaries, this method creates a signal that is 1 + when there is speech and 0 when there is no speech. + The vad signal has the same resolution as the input one and can be + opened with it (e.g, using audacity) to visually figure out VAD regions. + + Arguments + --------- + boundaries: torch.Tensor + torch.Tensor containing the boundaries of the speech segments. + audio_file: path + The original audio file used to compute vad_out + + Returns + ------- + vad_signal + The output vad signal with the same resolution of the input one. + """ + + # Getting the total size of the input file + sample_rate, sig_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + + # Initialization of the output signal + vad_signal = torch.zeros(1, sig_len, device=boundaries.device) + + # Composing the vad signal from boundaries + for i in range(boundaries.shape[0]): + beg_sample = int(boundaries[i, 0] * sample_rate) + end_sample = int(boundaries[i, 1] * sample_rate) + vad_signal[0, beg_sample:end_sample] = 1.0 + return vad_signal + + def double_check_speech_segments( + self, boundaries, audio_file, speech_th=0.5 + ): + """Takes in input the boundaries of the detected speech segments and + double checks (using the neural VAD) that they actually contain speech. + + Arguments + --------- + boundaries: torch.Tensor + torch.Tensor containing the boundaries of the speech segments. + audio_file: path + The original audio file used to compute vad_out. + speech_th: float + Threshold on the mean posterior probability over which speech is + confirmed. Below that threshold, the segment is re-assigned to a + non-speech region. + + Returns + ------- + new_boundaries + The boundaries of the segments where speech activity is confirmed. + """ + + # Getting the total size of the input file + sample_rate, sig_len = self._get_audio_info(audio_file) + + # Double check the segments + new_boundaries = [] + for i in range(boundaries.shape[0]): + beg_sample = int(boundaries[i, 0] * sample_rate) + end_sample = int(boundaries[i, 1] * sample_rate) + len_seg = end_sample - beg_sample + + # Read the candidate speech segment + segment, fs = audio_io.load( + str(audio_file), frame_offset=beg_sample, num_frames=len_seg + ) + speech_prob = self.get_speech_prob_chunk(segment) + if speech_prob.mean() > speech_th: + # Accept this as a speech segment + new_boundaries.append([boundaries[i, 0], boundaries[i, 1]]) + + # Convert boundaries from list to tensor + new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) + return new_boundaries + + def get_segments( + self, boundaries, audio_file, before_margin=0.1, after_margin=0.1 + ): + """Returns a list containing all the detected speech segments. + + Arguments + --------- + boundaries: torch.Tensor + torch.Tensor containing the boundaries of the speech segments. + audio_file: path + The original audio file used to compute vad_out. + before_margin: float + Used to cut the segments samples a bit before the detected margin. + after_margin: float + Use to cut the segments samples a bit after the detected margin. + + Returns + ------- + segments: list + List containing the detected speech segments + """ + sample_rate, sig_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + + segments = [] + for i in range(boundaries.shape[0]): + beg_sample = boundaries[i, 0] * sample_rate + end_sample = boundaries[i, 1] * sample_rate + + beg_sample = int(max(0, beg_sample - before_margin * sample_rate)) + end_sample = int( + min(sig_len, end_sample + after_margin * sample_rate) + ) + + len_seg = end_sample - beg_sample + vad_segment, fs = audio_io.load( + audio_file, frame_offset=beg_sample, num_frames=len_seg + ) + segments.append(vad_segment) + return segments + + def get_speech_segments( + self, + audio_file, + large_chunk_size=30, + small_chunk_size=10, + overlap_small_chunk=False, + apply_energy_VAD=False, + double_check=True, + close_th=0.250, + len_th=0.250, + activation_th=0.5, + deactivation_th=0.25, + en_activation_th=0.5, + en_deactivation_th=0.0, + speech_th=0.50, + ): + """Detects speech segments within the input file. The input signal can + be both a short or a long recording. The function computes the + posterior probabilities on large chunks (e.g, 30 sec), that are read + sequentially (to avoid storing big signals in memory). + Each large chunk is, in turn, split into smaller chunks (e.g, 10 seconds) + that are processed in parallel. The pipeline for detecting the speech + segments is the following: + 1- Compute posteriors probabilities at the frame level. + 2- Apply a threshold on the posterior probability. + 3- Derive candidate speech segments on top of that. + 4- Apply energy VAD within each candidate segment (optional). + 5- Merge segments that are too close. + 6- Remove segments that are too short. + 7- Double check speech segments (optional). + + Arguments + --------- + audio_file : str + Path to audio file. + large_chunk_size: float + Size (in seconds) of the large chunks that are read sequentially + from the input audio file. + small_chunk_size: float + Size (in seconds) of the small chunks extracted from the large ones. + The audio signal is processed in parallel within the small chunks. + Note that large_chunk_size/small_chunk_size must be an integer. + overlap_small_chunk: bool + If True, it creates overlapped small chunks (with 50% overlap). + The probabilities of the overlapped chunks are combined using + hamming windows. + apply_energy_VAD: bool + If True, a energy-based VAD is used on the detected speech segments. + The neural network VAD often creates longer segments and tends to + merge close segments together. The energy VAD post-processes can be + useful for having a fine-grained voice activity detection. + The energy thresholds is managed by activation_th and + deactivation_th (see below). + double_check: bool + If True, double checks (using the neural VAD) that the candidate + speech segments actually contain speech. A threshold on the mean + posterior probabilities provided by the neural network is applied + based on the speech_th parameter (see below). + close_th: float + If the distance between boundaries is smaller than close_th, the + segments will be merged. + len_th: float + If the length of the segment is smaller than close_th, the segments + will be merged. + activation_th: float + Threshold of the neural posteriors above which starting a speech segment. + deactivation_th: float + Threshold of the neural posteriors below which ending a speech segment. + en_activation_th: float + A new speech segment is started it the energy is above activation_th. + This is active only if apply_energy_VAD is True. + en_deactivation_th: float + The segment is considered ended when the energy is <= deactivation_th. + This is active only if apply_energy_VAD is True. + speech_th: float + Threshold on the mean posterior probability within the candidate + speech segment. Below that threshold, the segment is re-assigned to + a non-speech region. This is active only if double_check is True. + + Returns + ------- + boundaries: torch.Tensor + torch.Tensor containing the start second of speech segments in even + positions and their corresponding end in odd positions + (e.g, [1.0, 1.5, 5,.0 6.0] means that we have two speech segment; + one from 1.0 to 1.5 seconds and another from 5.0 to 6.0 seconds). + """ + + # Fetch audio file from web if not local + source, fl = split_path(audio_file) + audio_file = fetch(fl, source=source) + + # Computing speech vs non speech probabilities + prob_chunks = self.get_speech_prob_file( + audio_file, + large_chunk_size=large_chunk_size, + small_chunk_size=small_chunk_size, + overlap_small_chunk=overlap_small_chunk, + ) + + # Apply a threshold to get candidate speech segments + prob_th = self.apply_threshold( + prob_chunks, + activation_th=activation_th, + deactivation_th=deactivation_th, + ).float() + + # Compute the boundaries of the speech segments + boundaries = self.get_boundaries(prob_th, output_value="seconds") + + # Apply energy-based VAD on the detected speech segments + if apply_energy_VAD: + boundaries = self.energy_VAD( + audio_file, + boundaries, + activation_th=en_activation_th, + deactivation_th=en_deactivation_th, + ) + + # Merge short segments + boundaries = self.merge_close_segments(boundaries, close_th=close_th) + + # Remove short segments + boundaries = self.remove_short_segments(boundaries, len_th=len_th) + + # Double check speech segments + if double_check: + boundaries = self.double_check_speech_segments( + boundaries, audio_file, speech_th=speech_th + ) + + return boundaries + + def forward(self, wavs, wav_lens=None): + """Gets frame-level speech-activity predictions""" + return self.get_speech_prob_chunk(wavs, wav_lens) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/__init__.py new file mode 100644 index 000000000..1dbb62c53 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/__init__.py @@ -0,0 +1,17 @@ +"""Importing all the inference interfaces""" + +from . import * # noqa +from .ASR import * # noqa +from .classifiers import * # noqa +from .diarization import * # noqa +from .encoders import * # noqa +from .enhancement import * # noqa +from .interfaces import * # noqa +from .separation import * # noqa +from .SLU import * # noqa +from .speaker import * # noqa +from .ST import * # noqa +from .text import * # noqa +from .TTS import * # noqa +from .VAD import * # noqa +from .vocoders import * # noqa diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/classifiers.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/classifiers.py new file mode 100644 index 000000000..3c8428c31 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/classifiers.py @@ -0,0 +1,322 @@ +"""Specifies the inference interfaces for Audio Classification modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch +import torchaudio + +import speechbrain +from speechbrain.dataio import audio_io +from speechbrain.inference.interfaces import Pretrained +from speechbrain.utils.data_utils import split_path +from speechbrain.utils.fetching import LocalStrategy, fetch + + +class EncoderClassifier(Pretrained): + """A ready-to-use class for utterance-level classification (e.g, speaker-id, + language-id, emotion recognition, keyword spotting, etc). + + The class assumes that an encoder called "embedding_model" and a model + called "classifier" are defined in the yaml file. If you want to + convert the predicted index into a corresponding text label, please + provide the path of the label_encoder in a variable called 'lab_encoder_file' + within the yaml. + + The class can be used either to run only the encoder (encode_batch()) to + extract embeddings or to run a classification step (classify_batch()). + + Arguments + --------- + See ``Pretrained`` + + Example + ------- + >>> from speechbrain.dataio import audio_io + >>> from speechbrain.inference.classifiers import EncoderClassifier + >>> # Model is downloaded from the speechbrain HuggingFace repo + >>> tmpdir = getfixture("tmpdir") + >>> classifier = EncoderClassifier.from_hparams( + ... source="speechbrain/spkrec-ecapa-voxceleb", + ... savedir=tmpdir, + ... ) + >>> classifier.hparams.label_encoder.ignore_len() + + >>> # Compute embeddings + >>> signal, fs = audio_io.load("tests/samples/single-mic/example1.wav") + >>> embeddings = classifier.encode_batch(signal) + + >>> # Classification + >>> prediction = classifier.classify_batch(signal) + """ + + MODULES_NEEDED = [ + "compute_features", + "mean_var_norm", + "embedding_model", + "classifier", + ] + + def encode_batch(self, wavs, wav_lens=None, normalize=False): + """Encodes the input audio into a single vector embedding. + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = .normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. Make sure the sample rate is fs=16000 Hz. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + normalize : bool + If True, it normalizes the embeddings with the statistics + contained in mean_var_norm_emb. + + Returns + ------- + torch.Tensor + The encoded batch + """ + # Manage single waveforms in input + if len(wavs.shape) == 1: + wavs = wavs.unsqueeze(0) + + # Assign full length if wav_lens is not assigned + if wav_lens is None: + wav_lens = torch.ones(wavs.shape[0], device=self.device) + + # Storing waveform in the specified device + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + wavs = wavs.float() + + # Computing features and embeddings + feats = self.mods.compute_features(wavs) + feats = self.mods.mean_var_norm(feats, wav_lens) + embeddings = self.mods.embedding_model(feats, wav_lens) + if normalize: + embeddings = self.hparams.mean_var_norm_emb( + embeddings, torch.ones(embeddings.shape[0], device=self.device) + ) + return embeddings + + def classify_batch(self, wavs, wav_lens=None): + """Performs classification on the top of the encoded features. + + It returns the posterior probabilities, the index and, if the label + encoder is specified it also the text label. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. Make sure the sample rate is fs=16000 Hz. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + out_prob + The log posterior probabilities of each class ([batch, N_class]) + score: + It is the value of the log-posterior for the best class ([batch,]) + index + The indexes of the best class ([batch,]) + text_lab: + List with the text labels corresponding to the indexes. + (label encoder should be provided). + """ + emb = self.encode_batch(wavs, wav_lens) + out_prob = self.mods.classifier(emb).squeeze(1) + score, index = torch.max(out_prob, dim=-1) + text_lab = self.hparams.label_encoder.decode_torch(index) + return out_prob, score, index, text_lab + + def classify_file(self, path, **kwargs): + """Classifies the given audiofile into the given set of labels. + + Arguments + --------- + path : str + Path to audio file to classify. + **kwargs : dict + Arguments forwarded to ``load_audio``. + + Returns + ------- + out_prob : torch.Tensor + The log posterior probabilities of each class ([batch, N_class]) + score : torch.Tensor + It is the value of the log-posterior for the best class ([batch,]) + index : torch.Tensor + The indexes of the best class ([batch,]) + text_lab : list of str + List with the text labels corresponding to the indexes. + (label encoder should be provided). + """ + waveform = self.load_audio(path, **kwargs) + # Fake a batch: + batch = waveform.unsqueeze(0) + rel_length = torch.tensor([1.0]) + emb = self.encode_batch(batch, rel_length) + out_prob = self.mods.classifier(emb).squeeze(1) + score, index = torch.max(out_prob, dim=-1) + text_lab = self.hparams.label_encoder.decode_torch(index) + return out_prob, score, index, text_lab + + def forward(self, wavs, wav_lens=None): + """Runs the classification""" + return self.classify_batch(wavs, wav_lens) + + +class AudioClassifier(Pretrained): + """A ready-to-use class for utterance-level classification (e.g, speaker-id, + language-id, emotion recognition, keyword spotting, etc). + + The class assumes that an encoder called "embedding_model" and a model + called "classifier" are defined in the yaml file. If you want to + convert the predicted index into a corresponding text label, please + provide the path of the label_encoder in a variable called 'lab_encoder_file' + within the yaml. + + The class can be used either to run only the encoder (encode_batch()) to + extract embeddings or to run a classification step (classify_batch()). + + Arguments + --------- + See ``Pretrained``. + + Example + ------- + >>> import torchaudio + >>> from speechbrain.inference.classifiers import AudioClassifier + >>> tmpdir = getfixture("tmpdir") + >>> classifier = AudioClassifier.from_hparams( + ... source="speechbrain/cnn14-esc50", + ... savedir=tmpdir, + ... ) + >>> signal = torch.randn(1, 16000) + >>> prediction, _, _, text_lab = classifier.classify_batch(signal) + >>> print(prediction.shape) + torch.Size([1, 1, 50]) + """ + + def classify_batch(self, wavs, wav_lens=None): + """Performs classification on the top of the encoded features. + + It returns the posterior probabilities, the index and, if the label + encoder is specified it also the text label. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. Make sure the sample rate is fs=16000 Hz. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + out_prob : torch.Tensor + The log posterior probabilities of each class ([batch, N_class]) + score : torch.Tensor + It is the value of the log-posterior for the best class ([batch,]) + index : torch.Tensor + The indexes of the best class ([batch,]) + text_lab : list of str + List with the text labels corresponding to the indexes. + (label encoder should be provided). + """ + wavs = wavs.to(self.device) + X_stft = self.mods.compute_stft(wavs) + X_stft_power = speechbrain.processing.features.spectral_magnitude( + X_stft, power=self.hparams.spec_mag_power + ) + + if self.hparams.use_melspectra: + net_input = self.mods.compute_fbank(X_stft_power) + else: + net_input = torch.log1p(X_stft_power) + + # Embeddings + sound classifier + embeddings = self.mods.embedding_model(net_input) + if embeddings.ndim == 4: + embeddings = embeddings.mean((-1, -2)) + + out_probs = self.mods.classifier(embeddings) + score, index = torch.max(out_probs, dim=-1) + text_lab = self.hparams.label_encoder.decode_torch(index) + return out_probs, score, index, text_lab + + def classify_file(self, path, savedir=None): + """Classifies the given audiofile into the given set of labels. + + Arguments + --------- + path : str + Path to audio file to classify. + savedir : str + Path to folder for caching downloads. + + Returns + ------- + out_prob + The log posterior probabilities of each class ([batch, N_class]) + score: + It is the value of the log-posterior for the best class ([batch,]) + index + The indexes of the best class ([batch,]) + text_lab: + List with the text labels corresponding to the indexes. + (label encoder should be provided). + """ + source, fl = split_path(path) + path = fetch( + fl, + source=source, + savedir=savedir, + local_strategy=LocalStrategy.SYMLINK, + ) + + batch, fs_file = audio_io.load(path) + batch = batch.to(self.device) + fs_model = self.hparams.sample_rate + + # resample the data if needed + if fs_file != fs_model: + print(f"Resampling the audio from {fs_file} Hz to {fs_model} Hz") + tf = torchaudio.transforms.Resample( + orig_freq=fs_file, new_freq=fs_model + ).to(self.device) + batch = batch.mean(dim=0, keepdim=True) + batch = tf(batch) + + out_probs, score, index, text_lab = self.classify_batch(batch) + return out_probs, score, index, text_lab + + def forward(self, wavs, wav_lens=None): + """Runs the classification""" + return self.classify_batch(wavs, wav_lens) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/diarization.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/diarization.py new file mode 100644 index 000000000..349e7e55c --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/diarization.py @@ -0,0 +1,241 @@ +"""Specifies the inference interfaces for diarization modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.inference.interfaces import Pretrained + + +class Speech_Emotion_Diarization(Pretrained): + """A ready-to-use SED interface (audio -> emotions and their durations) + + Arguments + --------- + See ``Pretrained`` + + Example + ------- + >>> from speechbrain.inference.diarization import Speech_Emotion_Diarization + >>> tmpdir = getfixture("tmpdir") + >>> sed_model = Speech_Emotion_Diarization.from_hparams( + ... source="speechbrain/emotion-diarization-wavlm-large", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> sed_model.diarize_file( + ... "speechbrain/emotion-diarization-wavlm-large/example.wav" + ... ) # doctest: +SKIP + """ + + MODULES_NEEDED = ["input_norm", "wav2vec", "output_mlp"] + + def diarize_file(self, path): + """Get emotion diarization of a spoken utterance. + + Arguments + --------- + path : str + Path to audio file which to diarize. + + Returns + ------- + list of dictionary: List[Dict[List]] + The emotions and their temporal boundaries. + """ + waveform = self.load_audio(path) + # Fake a batch: + batch = waveform.unsqueeze(0) + rel_length = torch.tensor([1.0]) + frame_class = self.diarize_batch(batch, rel_length, [path]) + return frame_class + + def encode_batch(self, wavs, wav_lens): + """Encodes audios into fine-grained emotional embeddings + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels]. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.Tensor + The encoded batch + """ + if len(wavs.shape) == 1: + wavs = wavs.unsqueeze(0) + + # Assign full length if wav_lens is not assigned + if wav_lens is None: + wav_lens = torch.ones(wavs.shape[0], device=self.device) + + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + + wavs = self.mods.input_norm(wavs, wav_lens) + outputs = self.mods.wav2vec2(wavs) + return outputs + + def diarize_batch(self, wavs, wav_lens, batch_id): + """Get emotion diarization of a batch of waveforms. + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels]. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + batch_id : torch.Tensor + id of each batch (file names etc.) + + Returns + ------- + list of dictionary: List[Dict[List]] + The emotions and their temporal boundaries. + """ + outputs = self.encode_batch(wavs, wav_lens) + averaged_out = self.hparams.avg_pool(outputs) + outputs = self.mods.output_mlp(averaged_out) + outputs = self.hparams.log_softmax(outputs) + score, index = torch.max(outputs, dim=-1) + preds = self.hparams.label_encoder.decode_torch(index) + results = self.preds_to_diarization(preds, batch_id) + return results + + def preds_to_diarization(self, prediction, batch_id): + """Convert frame-wise predictions into a dictionary of + diarization results. + + Arguments + --------- + prediction : torch.Tensor + Frame-wise predictions + batch_id : str + The id for this batch + + Returns + ------- + dictionary + A dictionary with the start/end of each emotion + """ + results = {} + + for i in range(len(prediction)): + pred = prediction[i] + lol = [] + for j in range(len(pred)): + start = round(self.hparams.stride * 0.02 * j, 2) + end = round(start + self.hparams.window_length * 0.02, 2) + lol.append([batch_id[i], start, end, pred[j]]) + + lol = self.merge_ssegs_same_emotion_adjacent(lol) + results[batch_id[i]] = [ + {"start": k[1], "end": k[2], "emotion": k[3]} for k in lol + ] + return results + + def forward(self, wavs, wav_lens, batch_id): + """Get emotion diarization for a batch of waveforms.""" + return self.diarize_batch(wavs, wav_lens, batch_id) + + def is_overlapped(self, end1, start2): + """Returns True if segments are overlapping. + + Arguments + --------- + end1 : float + End time of the first segment. + start2 : float + Start time of the second segment. + + Returns + ------- + overlapped : bool + True of segments overlapped else False. + + Example + ------- + >>> Speech_Emotion_Diarization.is_overlapped(None, 5.5, 3.4) + True + >>> Speech_Emotion_Diarization.is_overlapped(None, 5.5, 6.4) + False + """ + + return start2 <= end1 + + def merge_ssegs_same_emotion_adjacent(self, lol): + """Merge adjacent sub-segs if they are the same emotion. + + Arguments + --------- + lol : list of list + Each list contains [utt_id, sseg_start, sseg_end, emo_label]. + + Returns + ------- + new_lol : list of list + new_lol contains adjacent segments merged from the same emotion ID. + + Example + ------- + >>> from speechbrain.utils.EDER import merge_ssegs_same_emotion_adjacent + >>> lol = [ + ... ["u1", 0.0, 7.0, "a"], + ... ["u1", 7.0, 9.0, "a"], + ... ["u1", 9.0, 11.0, "n"], + ... ["u1", 11.0, 13.0, "n"], + ... ["u1", 13.0, 15.0, "n"], + ... ["u1", 15.0, 16.0, "a"], + ... ] + >>> merge_ssegs_same_emotion_adjacent(lol) + [['u1', 0.0, 9.0, 'a'], ['u1', 9.0, 15.0, 'n'], ['u1', 15.0, 16.0, 'a']] + """ + new_lol = [] + + # Start from the first sub-seg + sseg = lol[0] + flag = False + for i in range(1, len(lol)): + next_sseg = lol[i] + # IF sub-segments overlap AND has same emotion THEN merge + if ( + self.is_overlapped(sseg[2], next_sseg[1]) + and sseg[3] == next_sseg[3] + ): + sseg[2] = next_sseg[2] # just update the end time + # This is important. For the last sseg, if it is the same emotion then merge + # Make sure we don't append the last segment once more. Hence, set FLAG=True + if i == len(lol) - 1: + flag = True + new_lol.append(sseg) + else: + new_lol.append(sseg) + sseg = next_sseg + # Add last segment only when it was skipped earlier. + if flag is False: + new_lol.append(lol[-1]) + return new_lol diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/encoders.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/encoders.py new file mode 100644 index 000000000..b59838a94 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/encoders.py @@ -0,0 +1,272 @@ +"""Specifies the inference interfaces for speech and audio encoders. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.inference.interfaces import Pretrained + + +class WaveformEncoder(Pretrained): + """A ready-to-use waveformEncoder model + + It can be used to wrap different embedding models such as SSL ones (wav2vec2) + or speaker ones (Xvector) etc. Two functions are available: encode_batch and + encode_file. They can be used to obtain the embeddings directly from an audio + file or from a batch of audio tensors respectively. + + The given YAML must contain the fields specified in the *_NEEDED[] lists. + + Arguments + --------- + See ``Pretrained`` + + Example + ------- + >>> from speechbrain.inference.encoders import WaveformEncoder + >>> tmpdir = getfixture("tmpdir") + >>> ssl_model = WaveformEncoder.from_hparams( + ... source="speechbrain/ssl-wav2vec2-base-libri", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + >>> ssl_model.encode_file( + ... "samples/audio_samples/example_fr.wav" + ... ) # doctest: +SKIP + """ + + MODULES_NEEDED = ["encoder"] + + def encode_file(self, path, **kwargs): + """Encode the given audiofile into a sequence of embeddings. + + Arguments + --------- + path : str + Path to audio file which to encode. + **kwargs : dict + Arguments forwarded to ``load_audio`` + + Returns + ------- + torch.Tensor + The audiofile embeddings produced by this system. + """ + waveform = self.load_audio(path, **kwargs) + # Fake a batch: + batch = waveform.unsqueeze(0) + rel_length = torch.tensor([1.0]) + results = self.encode_batch(batch, rel_length) + return results["embeddings"] + + def encode_batch(self, wavs, wav_lens): + """Encodes the input audio into a sequence of hidden states + + The waveforms should already be in the model's desired format. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. + + Returns + ------- + torch.Tensor + The encoded batch + """ + wavs = wavs.float() + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + encoder_out = self.mods.encoder(wavs, wav_lens) + return encoder_out + + def forward(self, wavs, wav_lens): + """Runs the encoder""" + return self.encode_batch(wavs, wav_lens) + + +class MelSpectrogramEncoder(Pretrained): + """A MelSpectrogramEncoder class created for the Zero-Shot Multi-Speaker TTS models. + + This is for speaker encoder models using the PyTorch MelSpectrogram transform for compatibility with the + current TTS pipeline. + + This class can be used to encode a single waveform, a single mel-spectrogram, or a batch of mel-spectrograms. + + Arguments + --------- + See ``Pretrained`` + + Example + ------- + >>> import torchaudio + >>> from speechbrain.inference.encoders import MelSpectrogramEncoder + >>> # Model is downloaded from the speechbrain HuggingFace repo + >>> tmpdir = getfixture("tmpdir") + >>> encoder = MelSpectrogramEncoder.from_hparams( + ... source="speechbrain/tts-ecapa-voxceleb", + ... savedir=tmpdir, + ... ) # doctest: +SKIP + + >>> # Compute embedding from a waveform (sample_rate must match the sample rate of the encoder) + >>> from speechbrain.dataio import audio_io + >>> signal, fs = audio_io.load( + ... "tests/samples/single-mic/example1.wav" + ... ) # doctest: +SKIP + >>> spk_emb = encoder.encode_waveform(signal) # doctest: +SKIP + + >>> # Compute embedding from a mel-spectrogram (sample_rate must match the sample rate of the ecoder) + >>> mel_spec = encoder.mel_spectogram(audio=signal) # doctest: +SKIP + >>> spk_emb = encoder.encode_mel_spectrogram(mel_spec) # doctest: +SKIP + + >>> # Compute embeddings for a batch of mel-spectrograms + >>> spk_embs = encoder.encode_mel_spectrogram_batch( + ... mel_spec + ... ) # doctest: +SKIP + """ + + MODULES_NEEDED = ["normalizer", "embedding_model"] + + def dynamic_range_compression(self, x, C=1, clip_val=1e-5): + """Dynamic range compression for audio signals""" + return torch.log(torch.clamp(x, min=clip_val) * C) + + def mel_spectogram(self, audio): + """calculates MelSpectrogram for a raw audio signal + + Arguments + --------- + audio : torch.tensor + input audio signal + + Returns + ------- + mel : torch.Tensor + Mel-spectrogram + """ + from torchaudio import transforms + + audio_to_mel = transforms.MelSpectrogram( + sample_rate=self.hparams.sample_rate, + hop_length=self.hparams.hop_length, + win_length=self.hparams.win_length, + n_fft=self.hparams.n_fft, + n_mels=self.hparams.n_mel_channels, + f_min=self.hparams.mel_fmin, + f_max=self.hparams.mel_fmax, + power=self.hparams.power, + normalized=self.hparams.mel_normalized, + norm=self.hparams.norm, + mel_scale=self.hparams.mel_scale, + ).to(audio.device) + + mel = audio_to_mel(audio) + + if self.hparams.dynamic_range_compression: + mel = self.dynamic_range_compression(mel) + + return mel + + def encode_waveform(self, wav): + """ + Encodes a single waveform + + Arguments + --------- + + wav : torch.Tensor + waveform + + Returns + ------- + encoder_out : torch.Tensor + Speaker embedding for the input waveform + """ + + # Moves tensor to the appropriate device + wav = wav.to(self.device) + + # Computes mel-spectrogram + mel_spec = self.mel_spectogram(audio=wav) + + # Calls encode_mel_spectrogram to compute the speaker embedding + return self.encode_mel_spectrogram(mel_spec) + + def encode_mel_spectrogram(self, mel_spec): + """ + Encodes a single mel-spectrograms + + Arguments + --------- + + mel_spec : torch.Tensor + Mel-spectrograms + + Returns + ------- + encoder_out : torch.Tensor + Speaker embedding for the input mel-spectrogram + """ + + # Fakes a batch + batch = mel_spec + if len(mel_spec.shape) == 2: + batch = mel_spec.unsqueeze(0) + rel_length = torch.tensor([1.0]) + + # Calls encode_mel_spectrogram_batch to compute speaker embeddings + results = self.encode_mel_spectrogram_batch(batch, rel_length) + + return results + + def encode_mel_spectrogram_batch(self, mel_specs, lens=None): + """ + Encodes a batch of mel-spectrograms + + Arguments + --------- + + mel_specs : torch.Tensor + Mel-spectrograms + lens : torch.Tensor + Relative lengths of the mel-spectrograms + + Returns + ------- + encoder_out : torch.Tensor + Speaker embedding for the input mel-spectrogram batch + """ + + # Assigns full length if lens is not assigned + if lens is None: + lens = torch.ones(mel_specs.shape[0], device=self.device) + + # Moves the tensors to the appropriate device + mel_specs, lens = mel_specs.to(self.device), lens.to(self.device) + + # Computes speaker embeddings + mel_specs = torch.transpose(mel_specs, 1, 2) + feats = self.hparams.normalizer(mel_specs, lens) + encoder_out = self.hparams.embedding_model(feats) + + return encoder_out + + def __forward(self, mel_specs, lens): + """Runs the encoder""" + return self.encode_batch(mel_specs, lens) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/enhancement.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/enhancement.py new file mode 100644 index 000000000..6efe167cb --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/enhancement.py @@ -0,0 +1,373 @@ +"""Specifies the inference interfaces for speech enhancement modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 + * Jonas Rochdi 2025 +""" + +import torch + +from speechbrain.dataio import audio_io +from speechbrain.inference.interfaces import Pretrained +from speechbrain.utils.callchains import lengths_arg_exists + + +def pad_spec(Y, mode="zero_pad"): + """Pad tensor `Y` along axis 3 to 64 with the given algorithm.""" + T = Y.size(3) + if T % 64 != 0: + num_pad = 64 - T % 64 + else: + num_pad = 0 + if mode == "zero_pad": + pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0)) + elif mode == "reflection": + pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0, 0)) + elif mode == "replication": + pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0, 0)) + else: + raise NotImplementedError("This function hasn't been implemented yet.") + return pad2d(Y) + + +class SpectralMaskEnhancement(Pretrained): + """A ready-to-use model for speech enhancement. + + Arguments + --------- + See ``Pretrained``. + + Example + ------- + >>> import torch + >>> from speechbrain.inference.enhancement import SpectralMaskEnhancement + >>> # Model is downloaded from the speechbrain HuggingFace repo + >>> tmpdir = getfixture("tmpdir") + >>> enhancer = SpectralMaskEnhancement.from_hparams( + ... source="speechbrain/metricgan-plus-voicebank", + ... savedir=tmpdir, + ... ) + >>> enhanced = enhancer.enhance_file( + ... "speechbrain/metricgan-plus-voicebank/example.wav" + ... ) + """ + + HPARAMS_NEEDED = ["compute_stft", "spectral_magnitude", "resynth"] + MODULES_NEEDED = ["enhance_model"] + + def compute_features(self, wavs): + """Compute the log spectral magnitude features for masking. + + Arguments + --------- + wavs : torch.Tensor + A batch of waveforms to convert to log spectral mags. + + Returns + ------- + feats : torch.Tensor + The log spectral magnitude features. + """ + feats = self.hparams.compute_stft(wavs) + feats = self.hparams.spectral_magnitude(feats) + return torch.log1p(feats) + + def enhance_batch(self, noisy, lengths=None): + """Enhance a batch of noisy waveforms. + + Arguments + --------- + noisy : torch.Tensor + A batch of waveforms to perform enhancement on. + lengths : torch.Tensor + The lengths of the waveforms if the enhancement model handles them. + + Returns + ------- + wavs : torch.Tensor + A batch of enhanced waveforms of the same shape as input. + """ + noisy = noisy.to(self.device) + noisy_features = self.compute_features(noisy) + + # Perform masking-based enhancement, multiplying output with input. + if lengths is not None: + mask = self.mods.enhance_model(noisy_features, lengths=lengths) + else: + mask = self.mods.enhance_model(noisy_features) + enhanced = torch.mul(mask, noisy_features) + + # Return resynthesized waveforms + return self.hparams.resynth(torch.expm1(enhanced), noisy) + + def enhance_file(self, filename, output_filename=None, **kwargs): + """Enhance a wav file. + + Arguments + --------- + filename : str + Location on disk to load file for enhancement. + output_filename : str + If provided, writes enhanced data to this file. + **kwargs : dict + Arguments forwarded to ``load_audio``. + + Returns + ------- + wav : torch.Tensor + The enhanced waveform. + """ + noisy = self.load_audio(filename, **kwargs) + noisy = noisy.to(self.device) + + # Fake a batch: + batch = noisy.unsqueeze(0) + if lengths_arg_exists(self.enhance_batch): + enhanced = self.enhance_batch(batch, lengths=torch.tensor([1.0])) + else: + enhanced = self.enhance_batch(batch) + + if output_filename is not None: + audio_io.save( + path=output_filename, + src=enhanced, + sample_rate=self.hparams.compute_stft.sample_rate, + ) + + return enhanced.squeeze(0) + + +class WaveformEnhancement(Pretrained): + """A ready-to-use model for speech enhancement. + + Arguments + --------- + See ``Pretrained``. + + Example + ------- + >>> from speechbrain.inference.enhancement import WaveformEnhancement + >>> # Model is downloaded from the speechbrain HuggingFace repo + >>> tmpdir = getfixture("tmpdir") + >>> enhancer = WaveformEnhancement.from_hparams( + ... source="speechbrain/mtl-mimic-voicebank", + ... savedir=tmpdir, + ... ) + >>> enhanced = enhancer.enhance_file( + ... "speechbrain/mtl-mimic-voicebank/example.wav" + ... ) + """ + + MODULES_NEEDED = ["enhance_model"] + + def enhance_batch(self, noisy, lengths=None): + """Enhance a batch of noisy waveforms. + + Arguments + --------- + noisy : torch.Tensor + A batch of waveforms to perform enhancement on. + lengths : torch.Tensor + The lengths of the waveforms if the enhancement model handles them. + + Returns + ------- + torch.Tensor + A batch of enhanced waveforms of the same shape as input. + """ + noisy = noisy.to(self.device) + enhanced_wav, _ = self.mods.enhance_model(noisy) + return enhanced_wav + + def enhance_file(self, filename, output_filename=None, **kwargs): + """Enhance a wav file. + + Arguments + --------- + filename : str + Location on disk to load file for enhancement. + output_filename : str + If provided, writes enhanced data to this file. + **kwargs : dict + Arguments forwarded to ``load_audio`` + + Returns + ------- + enhanced : torch.Tensor + The enhanced waveform. + """ + noisy = self.load_audio(filename, **kwargs) + + # Fake a batch: + batch = noisy.unsqueeze(0) + enhanced = self.enhance_batch(batch) + + if output_filename is not None: + audio_io.save( + path=output_filename, + src=enhanced, + sample_rate=self.audio_normalizer.sample_rate, + ) + + return enhanced.squeeze(0) + + def forward(self, noisy, lengths=None): + """Runs enhancement on the noisy input""" + return self.enhance_batch(noisy, lengths) + + +class SGMSEEnhancement(Pretrained): + """Ready-to-use SGMSE speech enhancement. + + Arguments + --------- + See ``Pretrained``. + + Example + ------- + >>> from speechbrain.inference.enhancement import SGMSEEnhancement + >>> tmpdir = getfixture("tmpdir") + >>> enh = SGMSEEnhancement.from_hparams( + ... source="speechbrain/sgmse-voicebank", savedir=tmpdir + ... ) # doctest: +SKIP + >>> out = enh.enhance_file( + ... "speechbrain/sgmse-voicebank/example.wav" + ... ) # doctest: +SKIP + """ + + MODULES_NEEDED = ["score_model"] + HPARAMS_NEEDED = [ + "sample_rate", + "n_fft", + "hop_length", + "window_type", + "transform_type", + "spec_factor", + "sampling", + ] + + def _ensure_stft_setup(self): + if getattr(self, "_stft_ready", False): + return + n_fft = self.hparams.n_fft + self._window = self._get_window(self.hparams.window_type, n_fft).to( + self.device + ) + self._stft_kwargs = dict( + n_fft=n_fft, + hop_length=self.hparams.hop_length, + center=True, + return_complex=True, + ) + self._stft_ready = True + + def enhance_batch(self, noisy, lengths=None): + """Enhance a batch of noisy waveforms (B, T) → (B, T).""" + self._ensure_stft_setup() + + noisy = noisy.to(self.device) + # scale to [-1,1] by max abs per item (like the Brain inference) + norms = torch.clamp(noisy.abs().amax(dim=1, keepdim=True), min=1e-8) + y = noisy / norms + + # STFT + forward spec transform + channel dim + Y = self._spec_fwd(self._stft(y)).unsqueeze(1) # (B,1,F,T) + F_orig, T_orig_spec = Y.shape[-2:] + + # pad for U-Net constraints + Yp = pad_spec(Y, mode="reflection") + + # Call the SGMSE sampler on spectrograms + smp = self.hparams.sampling + x_hat = self.mods.score_model.enhance( + Yp, + sampler_type=smp.get("sampler_type", "pc"), + predictor=smp.get("predictor", "reverse_diffusion"), + corrector=smp.get("corrector", "ald"), + N=smp.get("N", 30), + corrector_steps=smp.get("corrector_steps", 1), + snr=smp.get("snr", 0.5), + ) # (B,1,F,T) + + # Trim padding, drop channel, inverse spec transform, iSTFT + Xh = x_hat[:, :, :F_orig, :T_orig_spec].squeeze(1) # (B,F,T) + Xh = self._spec_back(Xh) + enh = self._istft(Xh, length=y.size(1)) * norms # (B,T) + return enh + + def enhance_file(self, filename, output_filename=None, **kwargs): + """Enhance a wav file; optionally write to disk.""" + noisy = self.load_audio(filename, **kwargs).to(self.device) + enhanced = self.enhance_batch(noisy.unsqueeze(0)).squeeze(0) + + if output_filename is not None: + audio_io.save( + output_filename, + src=enhanced.unsqueeze(0).cpu(), + sample_rate=self.hparams.sample_rate, + ) + return enhanced + + def forward(self, noisy, lengths=None): + """Alias to enable nn.Module-style calls.""" + return self.enhance_batch(noisy, lengths) + + # HELPERS + def _stft(self, sig): + return torch.stft(sig, **{**self._stft_kwargs, "window": self._window}) + + def _istft(self, spec, length=None): + kw = dict(self._stft_kwargs) + kw.pop("return_complex", None) + kw["window"] = self._window + kw["length"] = length + return torch.istft(spec, **kw) + + def _spec_fwd(self, S): + ttype = self.hparams.transform_type + factor = self.hparams.spec_factor + e = getattr(self.hparams, "spec_abs_exponent", 1.0) + + if ttype == "exponent": + if e != 1.0: + mag, ph = S.abs() ** e, S.angle() + S = mag * torch.exp(1j * ph) + S = S * factor + elif ttype == "log": + mag, ph = torch.log1p(S.abs()), S.angle() + S = mag * torch.exp(1j * ph) + S = S * factor + return S + + def _spec_back(self, S): + ttype = self.hparams.transform_type + factor = self.hparams.spec_factor + e = getattr(self.hparams, "spec_abs_exponent", 1.0) + + if ttype == "exponent": + S = S / factor + if e != 1.0: + mag, ph = S.abs() ** (1.0 / e), S.angle() + S = mag * torch.exp(1j * ph) + elif ttype == "log": + S = S / factor + mag, ph = torch.expm1(S.abs()), S.angle() + S = mag * torch.exp(1j * ph) + return S + + def _get_window(self, window_type, n_fft): + if window_type == "sqrthann": + return torch.sqrt(torch.hann_window(n_fft, periodic=True)) + elif window_type == "hann": + return torch.hann_window(n_fft, periodic=True) + raise NotImplementedError(f"Window type {window_type} not implemented!") diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interfaces.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interfaces.py new file mode 100644 index 000000000..4b74c74ed --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interfaces.py @@ -0,0 +1,694 @@ +"""Defines interfaces for simple inference with pretrained models + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import sys +import warnings +from types import SimpleNamespace + +import torch +from hyperpyyaml import load_hyperpyyaml +from torch.nn import ( + DataParallel as DP, + SyncBatchNorm, +) +from torch.nn.parallel import DistributedDataParallel as DDP + +from speechbrain.dataio import audio_io +from speechbrain.dataio.batch import PaddedBatch, PaddedData +from speechbrain.dataio.preprocess import AudioNormalizer +from speechbrain.utils.autocast import AMPConfig, TorchAutocast +from speechbrain.utils.data_pipeline import DataPipeline +from speechbrain.utils.data_utils import split_path +from speechbrain.utils.distributed import infer_device +from speechbrain.utils.fetching import FetchConfig, LocalStrategy, fetch +from speechbrain.utils.logger import get_logger +from speechbrain.utils.run_opts import RunOptions +from speechbrain.utils.superpowers import import_from_path + +logger = get_logger(__name__) + + +def foreign_class( + source, + hparams_file="hyperparams.yaml", + pymodule_file="custom.py", + classname="CustomInterface", + savedir=None, + local_strategy: LocalStrategy = LocalStrategy.SYMLINK, + fetch_config: FetchConfig = FetchConfig(), + **kwargs, +): + """Thin wrapper for `pretrained_from_hparams()` that fetches and loads a custom class. + + The pymodule file should contain a class with the given classname. An + instance of that class is returned. The idea is to have a custom Pretrained + subclass in the file. The pymodule file is also added to the python path + before the Hyperparams YAML file is loaded, so it can contain any custom + implementations that are needed. + + .. warning:: + Caution should be used with this function as it can download and run + arbitrary code onto the machine this function is used on. Only use + this function when the target module is from a highly trusted source! + + Arguments + --------- + source : str or Path or FetchSource + The location to use for finding the model. See + ``speechbrain.utils.fetching.fetch`` for details. + hparams_file : str + The name of the hyperparameters file to use for constructing + the modules necessary for inference. Must contain two keys: + "modules" and "pretrainer", as described in `pretrained_from_hparams`. + pymodule_file : str + The name of the Python file containing the model's python class. The file + will be fetched from `source` and will be used to load the class code. + classname : str + The name of the model's Python class, which should be present in the + code of the `pymodule_file`. + savedir : Optional[Union[str, Path]] + Where to put the pretraining material. If not given, just use cache. + local_strategy : LocalStrategy, default LocalStrategy.SYMLINK + Type of caching to use for keeping a local copy. + fetch_config : FetchConfig + Configuration options for caching and other fetch behavior. + **kwargs + Arguments to pass to `pretrained_from_hparams` + + Returns + ------- + object + An instance of a class with the given classname from the given pymodule file. + """ + pymodule_local_path = fetch( + filename=pymodule_file, + source=source, + savedir=savedir, + save_filename=None, + local_strategy=local_strategy, + fetch_config=fetch_config, + ) + sys.path.append(str(pymodule_local_path.parent)) + + # Dynamically import the specified Python module and retrieve the class by name. + # This allows users to define custom model interfaces outside of SpeechBrain. + # After importing, passes the class (not an instance) to pretrained_from_hparams, + # which will handle loading and instantiation with the appropriate hyperparameters. + module = import_from_path(pymodule_local_path) + cls = getattr(module, classname) + return pretrained_from_hparams( + cls=cls, + source=source, + hparams_file=hparams_file, + savedir=savedir, + local_strategy=local_strategy, + fetch_config=fetch_config, + **kwargs, + ) + + +def pretrained_from_hparams( + cls, + source, + hparams_file="hyperparams.yaml", + overrides={}, + overrides_must_match=True, + savedir=None, + download_only=False, + local_strategy: LocalStrategy = LocalStrategy.SYMLINK, + fetch_config: FetchConfig = FetchConfig(), + **kwargs, +): + """Fetch and load an interface from an outside source + + The source can be a location on the filesystem or online/huggingface + + The hyperparams file should contain a "modules" key, which is a + dictionary of torch modules used for computation. + + The hyperparams file should contain a "pretrainer" key, which is a + speechbrain.utils.parameter_transfer.Pretrainer + + .. warning:: + Caution should be used with this function as it can download and run + arbitrary code onto the machine this function is used on. Only use + this function when the target hparams file is from a highly trusted source! + + Arguments + --------- + cls : Type[Pretrained] + The class to construct an instance of, usually a sub-type of Pretrained + source : str or Path or FetchSource + The location to use for finding the model. See + ``speechbrain.utils.fetching.fetch`` for details. + hparams_file : str + The name of the hyperparameters file to use for constructing + the modules necessary for inference. Must contain two keys: + "modules" and "pretrainer", as described. + overrides : dict + Any changes to make to the hparams file when it is loaded. + overrides_must_match : bool + Whether an error will be thrown when an override does not match + a corresponding key in the yaml_stream. + savedir : str or Path + Where to put the pretraining material. If not given, just use cache. + download_only : bool (default: False) + If true, class and instance creation is skipped. + local_strategy : LocalStrategy, default LocalStrategy.SYMLINK + Type of caching to use for keeping a local copy. + fetch_config : FetchConfig + Configuration options for caching and other fetch behavior. + **kwargs : dict + Arguments to forward to class constructor. + + Returns + ------- + object : Optional[Pretrained] + An instance of a Pretrained class, constructed from the hparams. + None is returned if the argument `download_only` is `True`. + """ + hparams_local_path = fetch( + filename=hparams_file, + source=source, + savedir=savedir, + save_filename=None, + local_strategy=local_strategy, + fetch_config=fetch_config, + ) + + # Load the modules: + with open(hparams_local_path, encoding="utf-8") as fin: + hparams = load_hyperpyyaml(fin, overrides, overrides_must_match) + + hparams["savedir"] = savedir + # Pretraining: + pretrainer = hparams["pretrainer"] + pretrainer.set_collect_in(savedir) + pretrainer.collect_files( + default_source=source, + local_strategy=local_strategy, + fetch_config=fetch_config, + ) + # Load on the CPU. Later the params can be moved elsewhere by specifying + if not download_only: + # run_opts={"device": ...} + pretrainer.load_collected() + return cls(modules=hparams["modules"], hparams=hparams, **kwargs) + + # Not strictly necessary, but let's be explicit here + else: + return None + + +class Pretrained(torch.nn.Module): + """Takes a trained model and makes predictions on new data. + + This is a base class which handles some common boilerplate. + It intentionally has an interface similar to ``Brain`` - these base + classes handle similar things. + + Subclasses of Pretrained should implement the actual logic of how + the pretrained system runs, and add methods with descriptive names + (e.g. transcribe_file() for ASR). + + Pretrained is a torch.nn.Module so that methods like .to() or .eval() can + work. Subclasses should provide a suitable forward() implementation: by + convention, it should be a method that takes a batch of audio signals and + runs the full model (as applicable). + + Arguments + --------- + modules : dict of str:torch.nn.Module pairs + The Torch modules that make up the learned system. These can be treated + in special ways (put on the right device, frozen, etc.). These are available + as attributes under ``self.mods``, like self.mods.model(x) + hparams : dict + Each key:value pair should consist of a string key and a hyperparameter + that is used within the overridden methods. These will + be accessible via an ``hparams`` attribute, using "dot" notation: + e.g., self.hparams.model(x). + run_opts : Optional[Union[RunOptions, dict]] + A set of options to change the runtime environment, see ``RunOptions`` for + a complete list. Some options are meant for training, and will not apply + for this instance intended for inference. + freeze_params : bool + To freeze (requires_grad=False) parameters or not. Normally in inference + you want to freeze the params. Also calls .eval() on all modules. + """ + + HPARAMS_NEEDED = [] + MODULES_NEEDED = [] + + def __init__( + self, modules=None, hparams=None, run_opts=None, freeze_params=True + ): + super().__init__() + + # Check which options have been overridden. Order of priority + # is lowest: default < hparams < run_opts: highest + if isinstance(run_opts, dict): + run_opts = RunOptions.from_dictionary(run_opts) + self.run_opt_defaults = RunOptions() + for arg, default in self.run_opt_defaults.as_dict().items(): + if run_opts is not None and arg in run_opts.overridden_args: + setattr(self, arg, run_opts[arg]) + + # If any arg from run_opt_defaults exist in hparams and + # not in command line args "run_opts" + elif hparams is not None and arg in hparams: + setattr(self, arg, hparams[arg]) + else: + setattr(self, arg, default) + + # If device was not provided, make a best guess + if self.device is None: + self.device = infer_device() + + # Set device type based on device string + if self.device == "cpu": + self.device_type = "cpu" + elif "cuda" in self.device: + self.device_type = "cuda" + # Set cuda device based on device string + try: + _, device_index = self.device.split(":") + torch.cuda.set_device(int(device_index)) + except (ValueError, IndexError, TypeError) as e: + logger.warning( + f"Could not parse CUDA device string '{self.device}': {e}. Falling back to device 0." + ) + torch.cuda.set_device(0) + + precision_dtype = AMPConfig.from_name(self.precision).dtype + self.inference_ctx = TorchAutocast( + device_type=self.device_type, dtype=precision_dtype + ) + + # Put modules on the right device, accessible with dot notation + self.mods = torch.nn.ModuleDict(modules) + for module in self.mods.values(): + if module is not None: + module.to(self.device) + + # Check MODULES_NEEDED and HPARAMS_NEEDED and + # make hyperparams available with dot notation + if self.HPARAMS_NEEDED and hparams is None: + raise ValueError("Need to provide hparams dict.") + if hparams is not None: + # Also first check that all required params are found: + for hp in self.HPARAMS_NEEDED: + if hp not in hparams: + raise ValueError(f"Need hparams['{hp}']") + self.hparams = SimpleNamespace(**hparams) + + # Prepare modules for computation, e.g. jit + self._prepare_modules(freeze_params) + + # Audio normalization + self.audio_normalizer = hparams.get( + "audio_normalizer", AudioNormalizer() + ) + + def _prepare_modules(self, freeze_params): + """Prepare modules for computation, e.g. jit. + + Arguments + --------- + freeze_params : bool + Whether to freeze the parameters and call ``eval()``. + """ + + # Make jit-able + self._compile() + self._wrap_distributed() + + # If we don't want to backprop, freeze the pretrained parameters + if freeze_params: + self.mods.eval() + for p in self.mods.parameters(): + p.requires_grad = False + + def load_audio(self, path, savedir=None): + """Load an audio file with this model's input spec + + When using a speech model, it is important to use the same type of data, + as was used to train the model. This means for example using the same + sampling rate and number of channels. It is, however, possible to + convert a file from a higher sampling rate to a lower one (downsampling). + Similarly, it is simple to downmix a stereo file to mono. + The path can be a local path, a web url, or a link to a huggingface repo. + """ + source, fl = split_path(path) + path = fetch(fl, source=source, savedir=savedir) + signal, sr = audio_io.load(str(path), channels_first=False) + signal = signal.to(self.device) + return self.audio_normalizer(signal, sr) + + def _compile(self): + """Compile requested modules with either JIT or TorchInductor.""" + compile_available = hasattr(torch, "compile") + + if not compile_available and self.compile_module_keys is not None: + raise ValueError( + "'compile_module_keys' specified, but this install of PyTorch " + "seems to be too old to support it." + ) + + # Modules to compile with torch.compile + compile_module_keys = set() + if self.compile: + if self.compile_module_keys is None: + compile_module_keys = set(self.mods) + else: + compile_module_keys = set(self.compile_module_keys) + logger.warning( + "--compile and --compile_module_keys are both specified. " + "Only modules specified in --compile_module_keys will be compiled." + ) + + # Modules to compile with jit + jit_module_keys = set() + if self.jit: + if self.jit_module_keys is None: + jit_module_keys = set(self.mods) + else: + jit_module_keys = set(self.jit_module_keys) + logger.warning( + "--jit and --jit_module_keys are both specified. " + "Only modules specified in --jit_module_keys will be compiled." + ) + + # find missing keys + for name in compile_module_keys | jit_module_keys: + if name not in self.mods: + raise ValueError( + f"module {name} is not defined in your hparams file." + ) + + # try 'torch.compile', remove successful compiles from JIT list + for name in compile_module_keys: + try: + module = torch.compile( + self.mods[name], + mode=self.compile_mode, + fullgraph=self.compile_using_fullgraph, + dynamic=self.compile_using_dynamic_shape_tracing, + ) + except Exception as e: + logger.warning( + f"'{name}' in 'compile_module_keys' failed to compile " + f"and will be skipped (may fallback onto JIT, if " + f"specified): {e}" + ) + continue + + self.mods[name] = module.to(self.device) + jit_module_keys.discard(name) + + for name in jit_module_keys: + module = torch.jit.script(self.mods[name]) + self.mods[name] = module.to(self.device) + + def _compile_jit(self): + warnings.warn("'_compile_jit' is deprecated; use '_compile' instead") + self._compile() + + def _wrap_distributed(self): + """Wrap modules with distributed wrapper when requested.""" + if not self.distributed_launch and not self.data_parallel_backend: + return + elif self.distributed_launch: + for name, module in self.mods.items(): + if any(p.requires_grad for p in module.parameters()): + # for ddp, all module must run on same GPU + module = SyncBatchNorm.convert_sync_batchnorm(module) + module = DDP(module, device_ids=[self.device]) + self.mods[name] = module + else: + # data_parallel_backend + for name, module in self.mods.items(): + if any(p.requires_grad for p in module.parameters()): + # if distributed_count = -1 then use all gpus + # otherwise, specify the set of gpu to use + if self.data_parallel_count == -1: + module = DP(module) + else: + module = DP( + module, [i for i in range(self.data_parallel_count)] + ) + self.mods[name] = module + + @classmethod + def from_hparams(cls, source, hparams_file="hyperparams.yaml", **kwargs): + """Fetch and load based from outside source based on HyperPyYAML file + + The source can be a location on the filesystem or online/huggingface + + The hyperparams file should contain a "modules" key, which is a + dictionary of torch modules used for computation. + + The hyperparams file should contain a "pretrainer" key, which is a + speechbrain.utils.parameter_transfer.Pretrainer + + .. warning:: + Caution should be used with this function as it can download and run + arbitrary code onto the machine this function is used on. Only use + this function when the target hparams file is from a highly trusted source! + + Arguments + --------- + source : str + The location to use for finding the model. See + ``speechbrain.utils.fetching.fetch`` for details. + hparams_file : str + The name of the hyperparameters file to use for constructing + the modules necessary for inference. Must contain two keys: + "modules" and "pretrainer", as described. + **kwargs : dict + Arguments to forward to `pretrained_from_hparams`. + + Returns + ------- + Instance of cls + """ + return pretrained_from_hparams( + cls=cls, source=source, hparams_file=hparams_file, **kwargs + ) + + +class EncodeDecodePipelineMixin: + """ + A mixin for pretrained models that makes it possible to specify an encoding pipeline and a decoding pipeline + """ + + def create_pipelines(self): + """ + Initializes the encode and decode pipeline + """ + self._run_init_steps(self.hparams.encode_pipeline) + self._run_init_steps(self.hparams.decode_pipeline) + self.encode_pipeline = DataPipeline( + static_data_keys=self.INPUT_STATIC_KEYS, + dynamic_items=self.hparams.encode_pipeline["steps"], + output_keys=self.hparams.encode_pipeline["output_keys"], + ) + self.decode_pipeline = DataPipeline( + static_data_keys=self.hparams.model_output_keys, + dynamic_items=self.hparams.decode_pipeline["steps"], + output_keys=self.OUTPUT_KEYS, + ) + + def _run_init_steps(self, pipeline_definition): + """Encode/decode pipelines may include initialization + steps, such as filling text encoders with tokens. Calling + this method will run them, if defined""" + steps = pipeline_definition.get("init", []) + for step in steps: + step_func = step.get("func") + if not step_func or not callable(step_func): + raise ValueError("Invalid pipeline init definition") + step_func() + + def _run_pipeline(self, pipeline, input, batch): + if batch: + output = pipeline(input) + else: + output = [pipeline(item) for item in input] + return output + + def _get_encode_pipeline_input(self, input): + return input if self.batch_inputs else self._itemize(input) + + def _get_decode_pipeline_input(self, model_output): + model_output_keys = getattr(self.hparams, "model_output_keys", None) + pipeline_input = model_output + if len(model_output_keys) == 1: + pipeline_input = (pipeline_input,) + # The input to a pipeline is a dictionary. If model_output_keys + # is provided, the output of the model is assumed to be a collection + # (e.g. a list or a tuple). + if model_output_keys: + pipeline_input = dict(zip(model_output_keys, pipeline_input)) + + # By default, the pipeline will be applied to in batch mode + # to the entire model input + if not self.batch_outputs: + pipeline_input = self._itemize(pipeline_input) + return pipeline_input + + def _itemize(self, pipeline_input): + first_item = next(iter(pipeline_input.values())) + keys, values = pipeline_input.keys(), pipeline_input.values() + batch_length = len(first_item) + return [ + dict(zip(keys, [value[idx] for value in values])) + for idx in range(batch_length) + ] + + def to_dict(self, data): + """ + Converts padded batches to dictionaries, leaves + other data types as is + + Arguments + --------- + data: object + a dictionary or a padded batch + + Returns + ------- + results: dict + the dictionary + """ + if isinstance(data, PaddedBatch): + data = { + key: self._get_value(data, key) + for key in self.hparams.encode_pipeline["output_keys"] + } + return data + + def _get_value(self, data, key): + """ + Retrieves the value associated with the specified key, dereferencing + .data where applicable + + Arguments + --------- + data: PaddedBatch + a padded batch + key: str + the key + + Returns + ------- + result: object + the result + """ + value = getattr(data, key) + if not self.input_use_padded_data and isinstance(value, PaddedData): + value = value.data + return value + + @property + def batch_inputs(self): + """ + Determines whether the input pipeline + operates on batches or individual examples + (true means batched) + + Returns + ------- + batch_inputs: bool + """ + return self.hparams.encode_pipeline.get("batch", True) + + @property + def input_use_padded_data(self): + """ + If turned on, raw PaddedData instances will be passed to + the model. If turned off, only .data will be used + + Returns + ------- + result: bool + whether padded data is used as is + """ + return self.hparams.encode_pipeline.get("use_padded_data", False) + + @property + def batch_outputs(self): + """ + Determines whether the output pipeline + operates on batches or individual examples + (true means batched) + + Returns + ------- + batch_outputs: bool + """ + return self.hparams.decode_pipeline.get("batch", True) + + def _collate(self, data): + if not self.batch_inputs: + collate_fn = getattr(self.hparams, "collate_fn", PaddedBatch) + data = collate_fn(data) + return data + + def encode_input(self, input): + """ + Encodes the inputs using the pipeline + + Arguments + --------- + input: dict + the raw inputs + + Returns + ------- + results: object + + """ + pipeline_input = self._get_encode_pipeline_input(input) + model_input = self._run_pipeline( + pipeline=self.encode_pipeline, + input=pipeline_input, + batch=self.batch_inputs, + ) + model_input = self._collate(model_input) + if hasattr(model_input, "to"): + model_input = model_input.to(self.device) + return self.to_dict(model_input) + + def decode_output(self, output): + """ + Decodes the raw model outputs + + Arguments + --------- + output: tuple + raw model outputs + + Returns + ------- + result: dict or list + the output of the pipeline + """ + pipeline_input = self._get_decode_pipeline_input(output) + return self._run_pipeline( + pipeline=self.decode_pipeline, + input=pipeline_input, + batch=self.batch_outputs, + ) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interpretability.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interpretability.py new file mode 100644 index 000000000..9dd51e7ef --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/interpretability.py @@ -0,0 +1,182 @@ +"""Specifies the inference interfaces for interpretability modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch +import torch.nn.functional as F +import torchaudio + +import speechbrain +from speechbrain.dataio import audio_io +from speechbrain.inference.interfaces import Pretrained +from speechbrain.processing.NMF import spectral_phase +from speechbrain.utils.data_utils import split_path +from speechbrain.utils.fetching import LocalStrategy, fetch + + +class PIQAudioInterpreter(Pretrained): + """ + This class implements the interface for the PIQ posthoc interpreter for an audio classifier. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.interpretability import PIQAudioInterpreter + >>> tmpdir = getfixture("tmpdir") + >>> interpreter = PIQAudioInterpreter.from_hparams( + ... source="speechbrain/PIQ-ESC50", + ... savedir=tmpdir, + ... ) + >>> signal = torch.randn(1, 16000) + >>> interpretation, _ = interpreter.interpret_batch(signal) + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def preprocess(self, wavs): + """Pre-process wavs to calculate STFTs""" + X_stft = self.mods.compute_stft(wavs) + X_stft_power = speechbrain.processing.features.spectral_magnitude( + X_stft, power=self.hparams.spec_mag_power + ) + X_stft_logpower = torch.log1p(X_stft_power) + + return X_stft_logpower, X_stft, X_stft_power + + def classifier_forward(self, X_stft_logpower): + """the forward pass for the classifier""" + hcat = self.mods.embedding_model(X_stft_logpower) + embeddings = hcat.mean((-1, -2)) + predictions = self.mods.classifier(embeddings).squeeze(1) + class_pred = predictions.argmax(1) + return hcat, embeddings, predictions, class_pred + + def invert_stft_with_phase(self, X_int, X_stft_phase): + """Inverts STFT spectra given phase.""" + X_stft_phase_sb = torch.cat( + ( + torch.cos(X_stft_phase).unsqueeze(-1), + torch.sin(X_stft_phase).unsqueeze(-1), + ), + dim=-1, + ) + + X_stft_phase_sb = X_stft_phase_sb[:, : X_int.shape[1], :, :] + if X_int.ndim == 3: + X_int = X_int.unsqueeze(-1) + X_wpsb = X_int * X_stft_phase_sb + x_int_sb = self.mods.compute_istft(X_wpsb) + return x_int_sb + + def interpret_batch(self, wavs): + """Classifies the given audio into the given set of labels. + It also provides the interpretation in the audio domain. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. Make sure the sample rate is fs=16000 Hz. + + Returns + ------- + x_int_sound_domain : torch.Tensor + The interpretation in the waveform domain + text_lab : str + The text label for the classification + """ + wavs = wavs.to(self.device) + X_stft_logpower, X_stft, X_stft_power = self.preprocess(wavs) + X_stft_phase = spectral_phase(X_stft) + + # Embeddings + sound classifier + hcat, embeddings, predictions, class_pred = self.classifier_forward( + X_stft_logpower + ) + + if self.hparams.use_vq: + xhat, hcat, z_q_x = self.mods.psi(hcat, class_pred) + else: + xhat = self.mods.psi.decoder(hcat) + xhat = xhat.squeeze(1) + Tmax = xhat.shape[1] + if self.hparams.use_mask_output: + xhat = F.sigmoid(xhat) + X_int = xhat * X_stft_logpower[:, :Tmax, :] + else: + xhat = F.softplus(xhat) + th = xhat.max() * self.hparams.mask_th + X_int = (xhat > th) * X_stft_logpower[:, :Tmax, :] + X_int = torch.expm1(X_int) + x_int_sound_domain = self.invert_stft_with_phase(X_int, X_stft_phase) + text_lab = self.hparams.label_encoder.decode_torch( + class_pred.unsqueeze(0) + ) + + return x_int_sound_domain, text_lab + + def interpret_file(self, path, savedir=None): + """Classifies the given audiofile into the given set of labels. + It also provides the interpretation in the audio domain. + + Arguments + --------- + path : str + Path to audio file to classify. + savedir : str + Path to cache directory. + + Returns + ------- + x_int_sound_domain : torch.Tensor + The interpretation in the waveform domain + text_lab : str + The text label for the classification + fs_model : int + The sampling frequency of the model. Useful to save the audio. + """ + source, fl = split_path(path) + path = fetch( + fl, + source=source, + savedir=savedir, + local_strategy=LocalStrategy.SYMLINK, + ) + + batch, fs_file = audio_io.load(path) + batch = batch.to(self.device) + fs_model = self.hparams.sample_rate + + # resample the data if needed + if fs_file != fs_model: + print(f"Resampling the audio from {fs_file} Hz to {fs_model} Hz") + tf = torchaudio.transforms.Resample( + orig_freq=fs_file, new_freq=fs_model + ).to(self.device) + batch = batch.mean(dim=0, keepdim=True) + batch = tf(batch) + + x_int_sound_domain, text_lab = self.interpret_batch(batch) + return x_int_sound_domain, text_lab, fs_model + + def forward(self, wavs, wav_lens=None): + """Runs the classification""" + return self.interpret_batch(wavs, wav_lens) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/metrics.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/metrics.py new file mode 100644 index 000000000..b397cfced --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/metrics.py @@ -0,0 +1,97 @@ +"""Specifies the inference interfaces for metric estimation modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.inference.interfaces import Pretrained + + +class SNREstimator(Pretrained): + """A "ready-to-use" SNR estimator.""" + + MODULES_NEEDED = ["encoder", "encoder_out"] + HPARAMS_NEEDED = ["stat_pooling", "snrmax", "snrmin"] + + def estimate_batch(self, mix, predictions): + """Run SI-SNR estimation on the estimated sources, and mixture. + + Arguments + --------- + mix : torch.Tensor + The mixture of sources of shape B X T + predictions : torch.Tensor + of size (B x T x C), + where B is batch size + T is number of time points + C is number of sources + + Returns + ------- + tensor + Estimate of SNR + """ + + predictions = predictions.permute(0, 2, 1) + predictions = predictions.reshape(-1, predictions.size(-1)) + + if hasattr(self.hparams, "separation_norm_type"): + if self.hparams.separation_norm_type == "max": + predictions = ( + predictions / predictions.max(dim=1, keepdim=True)[0] + ) + mix = mix / mix.max(dim=1, keepdim=True)[0] + + elif self.hparams.separation_norm_type == "stnorm": + predictions = ( + predictions - predictions.mean(dim=1, keepdim=True) + ) / predictions.std(dim=1, keepdim=True) + mix = (mix - mix.mean(dim=1, keepdim=True)) / mix.std( + dim=1, keepdim=True + ) + + min_T = min(predictions.shape[1], mix.shape[1]) + assert predictions.shape[1] == mix.shape[1], "lengths change" + + mix_repeat = mix.repeat(2, 1) + inp_cat = torch.cat( + [ + predictions[:, :min_T].unsqueeze(1), + mix_repeat[:, :min_T].unsqueeze(1), + ], + dim=1, + ) + + enc = self.mods.encoder(inp_cat) + enc = enc.permute(0, 2, 1) + enc_stats = self.hparams.stat_pooling(enc) + + # this gets the SI-SNR estimate in the compressed range 0-1 + snrhat = self.mods.encoder_out(enc_stats).squeeze() + + # get the SI-SNR estimate in the true range + snrhat = self.gettrue_snrrange(snrhat) + return snrhat + + def forward(self, mix, predictions): + """Just run the batch estimate""" + return self.estimate_batch(mix, predictions) + + def gettrue_snrrange(self, inp): + """Convert from 0-1 range to true snr range""" + range = self.hparams.snrmax - self.hparams.snrmin + inp = inp * range + inp = inp + self.hparams.snrmin + return inp diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/separation.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/separation.py new file mode 100644 index 000000000..4ee10609c --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/separation.py @@ -0,0 +1,129 @@ +"""Specifies the inference interfaces for speech separation modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch +import torch.nn.functional as F +import torchaudio + +from speechbrain.dataio import audio_io +from speechbrain.inference.interfaces import Pretrained +from speechbrain.utils.data_utils import split_path +from speechbrain.utils.fetching import LocalStrategy, fetch + + +class SepformerSeparation(Pretrained): + """A "ready-to-use" speech separation model. + + Uses Sepformer architecture. + + Example + ------- + >>> tmpdir = getfixture("tmpdir") + >>> model = SepformerSeparation.from_hparams( + ... source="speechbrain/sepformer-wsj02mix", savedir=tmpdir + ... ) + >>> mix = torch.randn(1, 400) + >>> est_sources = model.separate_batch(mix) + >>> print(est_sources.shape) + torch.Size([1, 400, 2]) + """ + + MODULES_NEEDED = ["encoder", "masknet", "decoder"] + + def separate_batch(self, mix): + """Run source separation on batch of audio. + + Arguments + --------- + mix : torch.Tensor + The mixture of sources. + + Returns + ------- + tensor + Separated sources + """ + + # Separation + mix = mix.to(self.device) + mix_w = self.mods.encoder(mix) + est_mask = self.mods.masknet(mix_w) + mix_w = torch.stack([mix_w] * self.hparams.num_spks) + sep_h = mix_w * est_mask + + # Decoding + est_source = torch.cat( + [ + self.mods.decoder(sep_h[i]).unsqueeze(-1) + for i in range(self.hparams.num_spks) + ], + dim=-1, + ) + + # T changed after conv1d in encoder, fix it here + T_origin = mix.size(1) + T_est = est_source.size(1) + if T_origin > T_est: + est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est)) + else: + est_source = est_source[:, :T_origin, :] + return est_source + + def separate_file(self, path, savedir=None): + """Separate sources from file. + + Arguments + --------- + path : str + Path to file which has a mixture of sources. It can be a local + path, a web url, or a huggingface repo. + savedir : path + Path where to store the wav signals (when downloaded from the web). + Returns + ------- + tensor + Separated sources + """ + source, fl = split_path(path) + path = fetch( + fl, + source=source, + savedir=savedir, + local_strategy=LocalStrategy.SYMLINK, + ) + + batch, fs_file = audio_io.load(path) + batch = batch.to(self.device) + fs_model = self.hparams.sample_rate + + # resample the data if needed + if fs_file != fs_model: + print(f"Resampling the audio from {fs_file} Hz to {fs_model} Hz") + tf = torchaudio.transforms.Resample( + orig_freq=fs_file, new_freq=fs_model + ).to(self.device) + batch = batch.mean(dim=0, keepdim=True) + batch = tf(batch) + + est_sources = self.separate_batch(batch) + est_sources = ( + est_sources / est_sources.abs().max(dim=1, keepdim=True)[0] + ) + return est_sources + + def forward(self, mix): + """Runs separation on the input mix""" + return self.separate_batch(mix) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/speaker.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/speaker.py new file mode 100644 index 000000000..10bc087a5 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/speaker.py @@ -0,0 +1,133 @@ +"""Specifies the inference interfaces for speaker recognition modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.inference.classifiers import EncoderClassifier + + +class SpeakerRecognition(EncoderClassifier): + """A ready-to-use model for speaker recognition. It can be used to + perform speaker verification with verify_batch(). + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> import torchaudio + >>> from speechbrain.inference.speaker import SpeakerRecognition + >>> # Model is downloaded from the speechbrain HuggingFace repo + >>> tmpdir = getfixture("tmpdir") + >>> verification = SpeakerRecognition.from_hparams( + ... source="speechbrain/spkrec-ecapa-voxceleb", + ... savedir=tmpdir, + ... ) + + >>> # Perform verification + >>> from speechbrain.dataio import audio_io + >>> signal, fs = audio_io.load("tests/samples/single-mic/example1.wav") + >>> signal2, fs = audio_io.load("tests/samples/single-mic/example2.flac") + >>> score, prediction = verification.verify_batch(signal, signal2) + """ + + MODULES_NEEDED = [ + "compute_features", + "mean_var_norm", + "embedding_model", + "mean_var_norm_emb", + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.similarity = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + def verify_batch( + self, wavs1, wavs2, wav1_lens=None, wav2_lens=None, threshold=0.25 + ): + """Performs speaker verification with cosine distance. + + It returns the score and the decision (0 different speakers, + 1 same speakers). + + Arguments + --------- + wavs1 : Torch.Tensor + torch.Tensor containing the speech waveform1 (batch, time). + Make sure the sample rate is fs=16000 Hz. + wavs2 : Torch.Tensor + torch.Tensor containing the speech waveform2 (batch, time). + Make sure the sample rate is fs=16000 Hz. + wav1_lens : Torch.Tensor + torch.Tensor containing the relative length for each sentence + in the length (e.g., [0.8 0.6 1.0]) + wav2_lens : Torch.Tensor + torch.Tensor containing the relative length for each sentence + in the length (e.g., [0.8 0.6 1.0]) + threshold : Float + Threshold applied to the cosine distance to decide if the + speaker is different (0) or the same (1). + + Returns + ------- + score + The score associated to the binary verification output + (cosine distance). + prediction + The prediction is 1 if the two signals in input are from the same + speaker and 0 otherwise. + """ + emb1 = self.encode_batch(wavs1, wav1_lens, normalize=False) + emb2 = self.encode_batch(wavs2, wav2_lens, normalize=False) + score = self.similarity(emb1, emb2) + return score, score > threshold + + def verify_files(self, path_x, path_y, **kwargs): + """Speaker verification with cosine distance + + Returns the score and the decision (0 different speakers, + 1 same speakers). + + Arguments + --------- + path_x : str + Path to file x + path_y : str + Path to file y + **kwargs : dict + Arguments to ``load_audio`` + + Returns + ------- + score + The score associated to the binary verification output + (cosine distance). + prediction + The prediction is 1 if the two signals in input are from the same + speaker and 0 otherwise. + """ + waveform_x = self.load_audio(path_x, **kwargs) + waveform_y = self.load_audio(path_y, **kwargs) + # Fake batches: + batch_x = waveform_x.unsqueeze(0) + batch_y = waveform_y.unsqueeze(0) + # Verify: + score, decision = self.verify_batch(batch_x, batch_y) + # Squeeze: + return score[0], decision[0] diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/text.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/text.py new file mode 100644 index 000000000..6e25c69d8 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/text.py @@ -0,0 +1,443 @@ +"""Specifies the inference interfaces for text-processing modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +from itertools import chain + +import torch + +from speechbrain.inference.interfaces import ( + EncodeDecodePipelineMixin, + Pretrained, +) + + +class GraphemeToPhoneme(Pretrained, EncodeDecodePipelineMixin): + """ + A pretrained model implementation for Grapheme-to-Phoneme (G2P) models + that take raw natural language text as an input and + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> text = ( + ... "English is tough. It can be understood " + ... "through thorough thought though" + ... ) + >>> from speechbrain.inference.text import GraphemeToPhoneme + >>> tmpdir = getfixture("tmpdir") + >>> g2p = GraphemeToPhoneme.from_hparams( + ... "path/to/model", savedir=tmpdir + ... ) # doctest: +SKIP + >>> phonemes = g2p.g2p(text) # doctest: +SKIP + """ + + INPUT_STATIC_KEYS = ["txt"] + OUTPUT_KEYS = ["phonemes"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.create_pipelines() + self.load_dependencies() + + @property + def phonemes(self): + """Returns the available phonemes""" + return self.hparams.phonemes + + @property + def language(self): + """Returns the language for which this model is available""" + return self.hparams.language + + def g2p(self, text): + """Performs the Grapheme-to-Phoneme conversion + + Arguments + --------- + text: str or list[str] + a single string to be encoded to phonemes - or a + sequence of strings + + Returns + ------- + result: list + if a single example was provided, the return value is a + single list of phonemes + """ + single = isinstance(text, str) + if single: + text = [text] + + encoded_inputs = self.encode_input({"txt": text}) + self._update_graphemes(encoded_inputs) + + model_inputs = encoded_inputs + if hasattr(self.hparams, "model_input_keys"): + model_inputs = { + k: model_inputs[k] for k in self.hparams.model_input_keys + } + + model_outputs = self.mods.model(**model_inputs) + decoded_output = self.decode_output(model_outputs) + phonemes = decoded_output["phonemes"] + phonemes = self._remove_eos(phonemes) + if single: + phonemes = phonemes[0] + return phonemes + + def _remove_eos(self, phonemes): + """Removes the EOS character from the end of the sequence, + if encountered + + Arguments + --------- + phonemes : list + a list of phomemic transcriptions + + Returns + ------- + result : list + phonemes, without + """ + return [ + item[:-1] if item and item[-1] == "" else item + for item in phonemes + ] + + def _update_graphemes(self, model_inputs): + grapheme_sequence_mode = self.hparams.grapheme_sequence_mode + if grapheme_sequence_mode and grapheme_sequence_mode != "raw": + grapheme_encoded_key = f"grapheme_encoded_{grapheme_sequence_mode}" + if grapheme_encoded_key in model_inputs: + model_inputs["grapheme_encoded"] = model_inputs[ + grapheme_encoded_key + ] + + def load_dependencies(self): + """Loads any relevant model dependencies""" + deps_pretrainer = getattr(self.hparams, "deps_pretrainer", None) + if deps_pretrainer: + deps_pretrainer.collect_files() + deps_pretrainer.load_collected() + + def __call__(self, text): + """A convenience callable wrapper - same as G2P + + Arguments + --------- + text: str or list[str] + a single string to be encoded to phonemes - or a + sequence of strings + + Returns + ------- + result: list + if a single example was provided, the return value is a + single list of phonemes + """ + return self.g2p(text) + + def forward(self, noisy, lengths=None): + """Runs enhancement on the noisy input""" + return self.enhance_batch(noisy, lengths) + + +class ResponseGenerator(Pretrained): + """A ready-to-use Response Generator model + + The class can be used to generate and continue dialogue given the user input. + The given YAML must contain the fields specified in the *_NEEDED[] lists. + It needs to be used with custom.py to load the expanded model with added tokens like bos,eos, and speaker's tokens. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + """ + + MODULES_NEEDED = ["model"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Load model + self.model = self.hparams.model + self.tokenizer = self.model.tokenizer + self.history_window = 2 * self.hparams.max_history + 1 + self.history = [] + + def generate_response(self, turn): + """ + Complete a dialogue given the user's input. + Arguments + --------- + turn: str + User input which is the last turn of the dialogue. + + Returns + ------- + response + Generated response for the user input based on the dialogue history. + """ + + self.history.append(turn) + inputs = self.prepare_input() + hyps = self.generate(inputs) + predicted_words = self.model.tokenizer.batch_decode( + hyps[:, inputs[0].shape[1] :], + skip_special_tokens=True, + clean_up_tokenization_spaces=True, + ) + response = predicted_words[0] + self.history.append(response) + return response + + def prepare_input(self): + """Users should modify this function according to their own tasks.""" + raise NotImplementedError + + def generate(self): + """Users should modify this function according to their own tasks.""" + raise NotImplementedError + + +class GPTResponseGenerator(ResponseGenerator): + """A ready-to-use Response Generator model + + The class can be used to generate and continue dialogue given the user input. + The given YAML must contain the fields specified in the *_NEEDED[] lists. + It needs to be used with custom.py to load the expanded GPT model with added tokens like bos,eos, and speaker's tokens. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.text import GPTResponseGenerator + + >>> tmpdir = getfixture("tmpdir") + >>> res_gen_model = GPTResponseGenerator.from_hparams( + ... source="speechbrain/MultiWOZ-GPT-Response_Generation", + ... pymodule_file="custom.py", + ... ) # doctest: +SKIP + >>> response = res_gen_model.generate_response( + ... "I want to book a table for dinner" + ... ) # doctest: +SKIP + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # convert special tokens to their ids + ( + self.bos, + self.eos, + self.system, + self.user, + ) = self.model.tokenizer.convert_tokens_to_ids( + self.hparams.special_tokens + ) + + def generate(self, inputs): + """ + Complete a dialogue given the user's input. + + Arguments + --------- + inputs: tuple + history_bos which is the tokenized history+input values with appropriate speaker token appended before each turn and history_token_type which determines + the type of each token based on who is uttered that token (either User or System). + + Returns + ------- + response + Generated hypothesis for the user input based on the dialogue history. + """ + + history_bos, history_token_type = inputs + padding_mask = ~self.hparams.padding_mask( + history_bos, pad_idx=self.model.tokenizer.unk_token_id + ) + hyps = self.model.generate( + history_bos.detach(), + history_token_type.detach(), + padding_mask.detach(), + "beam", + ) + return hyps + + def prepare_input(self): + """Convert user input and previous histories to the format acceptable for GPT model. + It appends all previous history and input and truncates it based on max_history value. + It then tokenizes the input and generates additional input that determines the type of each token (System or User). + + Returns + ------- + history_bos: torch.Tensor + Tokenized history+input values with appropriate speaker token appended before each turn. + history_token_type: torch.LongTensor + Type of each token based on who is uttered that token (either User or System) + """ + history_tokens_lists = [ + self.model.tokenizer.encode(turn) for turn in self.history + ] + # add speaker tokens to the history turns (user is even, system is odd) + # BEFORE: [Hi how are you?], [I'm fine, thanks] + # AFTER: [SPK_1 Hi how are you?], [SPK_2 I'm fine, thanks] + history_input_lists = [ + [self.user if i % 2 == 0 else self.system] + encoded_turn + for i, encoded_turn in enumerate(history_tokens_lists) + ] + history_ids = history_input_lists[-self.history_window :] + # concatenate every token into a single list + # list(chain(*[[1, 2], [3, 4], [5]])) + # >>> [1, 2, 3, 4, 5] + history_ids = torch.LongTensor(list(chain(*history_ids))) + # create bos version for the input + history_bos = torch.cat( + (torch.tensor([self.bos]), history_ids, torch.tensor([self.system])) + ) + # create a mapping that associates each token in the input to a speaker + # INPUT: [SPK_1 Hi how are you? ], [SPK_2 I'm fine, thanks] + # TYPE: [SPK_1 SPK_1 SPK_1 SPK_1 SPK_1], [SPK_2 SPK_2 SPK_2 SPK_2 ] + history_token_type_lists = [ + [self.user if i % 2 == 0 else self.system] * len(encoded_turn) + for i, encoded_turn in enumerate(history_input_lists) + ] + history_token_type = torch.LongTensor( + list( + chain( + *( + [[self.system]] + + history_token_type_lists[-self.history_window :] + + [[self.system]] + ) + ) + ) + ) + return history_bos.unsqueeze(0), history_token_type.unsqueeze(0) + + +class Llama2ResponseGenerator(ResponseGenerator): + """A ready-to-use Response Generator model + + The class can be used to generate and continue dialogue given the user input. + The given YAML must contain the fields specified in the *_NEEDED[] lists. + It needs to be used with custom.py to load the expanded Llama2 model with added tokens like bos,eos, and speaker's tokens. + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> from speechbrain.inference.text import Llama2ResponseGenerator + + >>> tmpdir = getfixture("tmpdir") + >>> res_gen_model = Llama2ResponseGenerator.from_hparams( + ... source="speechbrain/MultiWOZ-Llama2-Response_Generation", + ... pymodule_file="custom.py", + ... ) # doctest: +SKIP + >>> response = res_gen_model.generate_response( + ... "I want to book a table for dinner" + ... ) # doctest: +SKIP + """ + + def __init__(self, *args, **kwargs): + run_opts = {"device": "cuda"} + super().__init__(run_opts=run_opts, *args, **kwargs) + # self.model = self.model#.to("cuda") + + def generate(self, inputs): + """ + Complete a dialogue given the user's input. + Arguments + --------- + inputs: prompt_bos + prompted inputs to be passed to llama2 model for generation. + + Returns + ------- + response + Generated hypothesis for the user input based on the dialogue history. + """ + prompt_bos = inputs[0].to(self.model.model.device) + padding_mask = ~self.hparams.padding_mask( + prompt_bos, pad_idx=self.tokenizer.pad_token_id + ) + hyps = self.model.generate( + prompt_bos.detach(), + padding_mask.detach(), + "beam", + ) + return hyps + + def prepare_input(self): + """Convert user input and previous histories to the format acceptable for Llama2 model. + It appends all previous history and input and truncates it based on max_history value. + It then tokenizes the input and add prompts. + + Returns + ------- + prompt_bos: torch.Tensor + Tokenized history+input values with appropriate prompt. + """ + + def generate_prompt(idx_and_item): + """add [INST] and [/INST] prompt to the start and end ogf item. + + Arguments + --------- + idx_and_item: tuple + id and its corresponding text. If the id is even, it is user turn and [ INST] is added. + + Returns + ------- + prompt_bos: torch.LongTensor + prompted text for one item. + """ + index, item = idx_and_item + if index % 2 == 0: + return "[INST] " + item + " [/INST]" + else: + return item + + prompts = list(map(generate_prompt, enumerate(self.history))) + + # encode each turn of the history + prompt_tokens_lists = [self.tokenizer.encode(turn) for turn in prompts] + + prompt_ids = prompt_tokens_lists[-self.history_window :] + # concatenate every token into a single list + # list(chain(*[[1, 2], [3, 4], [5]])) + # >>> [1, 2, 3, 4, 5] + prompt_ids = torch.LongTensor(list(chain(*prompt_ids))) + # without bos for lm_labels + + # # create bos version for the input + prompt_bos = torch.cat( + (torch.tensor([self.tokenizer.bos_token_id]), prompt_ids) + ) + return prompt_bos.unsqueeze(0).unsqueeze(dim=0) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/vocoders.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/vocoders.py new file mode 100644 index 000000000..d64a4f9a6 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/inference/vocoders.py @@ -0,0 +1,399 @@ +"""Specifies the inference interfaces for Text-To-Speech (TTS) modules. + +Authors: + * Aku Rouhe 2021 + * Peter Plantinga 2021 + * Loren Lugosch 2020 + * Mirco Ravanelli 2020 + * Titouan Parcollet 2021 + * Abdel Heba 2021 + * Andreas Nautsch 2022, 2023 + * Pooneh Mousavi 2023 + * Sylvain de Langen 2023 + * Adel Moumen 2023 + * Pradnya Kandarkar 2023 +""" + +import torch + +from speechbrain.dataio.dataio import length_to_mask +from speechbrain.inference.interfaces import Pretrained +from speechbrain.utils.logger import get_logger + +logger = get_logger(__name__) + + +class HIFIGAN(Pretrained): + """ + A ready-to-use wrapper for HiFiGAN (mel_spec -> waveform). + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + + Example + ------- + >>> tmpdir_vocoder = getfixture("tmpdir") / "vocoder" + >>> hifi_gan = HIFIGAN.from_hparams( + ... source="speechbrain/tts-hifigan-ljspeech", savedir=tmpdir_vocoder + ... ) + >>> mel_specs = torch.rand(2, 80, 298) + >>> waveforms = hifi_gan.decode_batch(mel_specs) + >>> # You can use the vocoder coupled with a TTS system + >>> # Initialize TTS (tacotron2) + >>> tmpdir_tts = getfixture("tmpdir") / "tts" + >>> from speechbrain.inference.TTS import Tacotron2 + >>> tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir=tmpdir_tts) + >>> # Running the TTS + >>> mel_output, mel_length, alignment = tacotron2.encode_text("Mary had a little lamb") + >>> # Running Vocoder (spectrogram-to-waveform) + >>> waveforms = hifi_gan.decode_batch(mel_output) + """ + + HPARAMS_NEEDED = ["generator"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.infer = self.hparams.generator.inference + self.first_call = True + + def decode_batch(self, spectrogram, mel_lens=None, hop_len=None): + """Computes waveforms from a batch of mel-spectrograms + + Arguments + --------- + spectrogram: torch.Tensor + Batch of mel-spectrograms [batch, mels, time] + mel_lens: torch.tensor + A list of lengths of mel-spectrograms for the batch + Can be obtained from the output of Tacotron/FastSpeech + hop_len: int + hop length used for mel-spectrogram extraction + should be the same value as in the .yaml file + + Returns + ------- + waveforms: torch.Tensor + Batch of mel-waveforms [batch, 1, time] + """ + # Prepare for inference by removing the weight norm + if self.first_call: + self.hparams.generator.remove_weight_norm() + self.first_call = False + with torch.no_grad(): + waveform = self.infer(spectrogram.to(self.device)) + + # Mask the noise caused by padding during batch inference + if mel_lens is not None and hop_len is not None: + waveform = self.mask_noise(waveform, mel_lens, hop_len) + + return waveform + + def mask_noise(self, waveform, mel_lens, hop_len): + """Mask the noise caused by padding during batch inference + + Arguments + --------- + waveform: torch.tensor + Batch of generated waveforms [batch, 1, time] + mel_lens: torch.tensor + A list of lengths of mel-spectrograms for the batch + Can be obtained from the output of Tacotron/FastSpeech + hop_len: int + hop length used for mel-spectrogram extraction + same value as in the .yaml file + + Returns + ------- + waveform: torch.tensor + Batch of waveforms without padded noise [batch, 1, time] + """ + waveform = waveform.squeeze(1) + # the correct audio length should be hop_len * mel_len + mask = length_to_mask( + mel_lens * hop_len, waveform.shape[1], device=waveform.device + ).bool() + waveform.masked_fill_(~mask, 0.0) + return waveform.unsqueeze(1) + + def decode_spectrogram(self, spectrogram): + """Computes waveforms from a single mel-spectrogram + + Arguments + --------- + spectrogram: torch.Tensor + mel-spectrogram [mels, time] + + Returns + ------- + waveform: torch.Tensor + waveform [1, time] + audio can be saved by: + >>> from speechbrain.dataio import audio_io + >>> waveform = torch.rand(1, 666666) + >>> sample_rate = 22050 + >>> audio_io.save( + ... str(getfixture("tmpdir") / "test.wav"), waveform, sample_rate + ... ) + """ + if self.first_call: + self.hparams.generator.remove_weight_norm() + self.first_call = False + with torch.no_grad(): + waveform = self.infer(spectrogram.unsqueeze(0).to(self.device)) + return waveform.squeeze(0) + + def forward(self, spectrogram): + "Decodes the input spectrograms" + return self.decode_batch(spectrogram) + + +class DiffWaveVocoder(Pretrained): + """ + A ready-to-use inference wrapper for DiffWave as vocoder. + The wrapper allows to perform generative tasks: + locally-conditional generation: mel_spec -> waveform + + Arguments + --------- + *args : tuple + **kwargs : dict + Arguments are forwarded to ``Pretrained`` parent class. + """ + + HPARAMS_NEEDED = ["diffusion"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if hasattr(self.hparams, "diffwave"): + self.infer = self.hparams.diffusion.inference + else: + raise NotImplementedError + + def decode_batch( + self, + mel, + hop_len, + mel_lens=None, + fast_sampling=False, + fast_sampling_noise_schedule=None, + ): + """Generate waveforms from spectrograms + + Arguments + --------- + mel: torch.tensor + spectrogram [batch, mels, time] + hop_len: int + Hop length during mel-spectrogram extraction + Should be the same value as in the .yaml file + Used to determine the output wave length + Also used to mask the noise for vocoding task + mel_lens: torch.tensor + Used to mask the noise caused by padding + A list of lengths of mel-spectrograms for the batch + Can be obtained from the output of Tacotron/FastSpeech + fast_sampling: bool + whether to do fast sampling + fast_sampling_noise_schedule: list + the noise schedules used for fast sampling + Returns + ------- + waveforms: torch.tensor + Batch of mel-waveforms [batch, 1, time] + + """ + with torch.no_grad(): + waveform = self.infer( + unconditional=False, + scale=hop_len, + condition=mel.to(self.device), + fast_sampling=fast_sampling, + fast_sampling_noise_schedule=fast_sampling_noise_schedule, + ) + + # Mask the noise caused by padding during batch inference + if mel_lens is not None and hop_len is not None: + waveform = self.mask_noise(waveform, mel_lens, hop_len) + return waveform + + def mask_noise(self, waveform, mel_lens, hop_len): + """Mask the noise caused by padding during batch inference + + Arguments + --------- + waveform: torch.tensor + Batch of generated waveforms [batch, 1, time] + mel_lens: torch.tensor + A list of lengths of mel-spectrograms for the batch + Can be obtained from the output of Tacotron/FastSpeech + hop_len: int + hop length used for mel-spectrogram extraction + same value as in the .yaml file + + Returns + ------- + waveform: torch.tensor + Batch of waveforms without padded noise [batch, 1, time] + """ + waveform = waveform.squeeze(1) + # the correct audio length should be hop_len * mel_len + mask = length_to_mask( + mel_lens * hop_len, waveform.shape[1], device=waveform.device + ).bool() + waveform.masked_fill_(~mask, 0.0) + return waveform.unsqueeze(1) + + def decode_spectrogram( + self, + spectrogram, + hop_len, + fast_sampling=False, + fast_sampling_noise_schedule=None, + ): + """Computes waveforms from a single mel-spectrogram + + Arguments + --------- + spectrogram: torch.tensor + mel-spectrogram [mels, time] + hop_len: int + hop length used for mel-spectrogram extraction + same value as in the .yaml file + fast_sampling: bool + whether to do fast sampling + fast_sampling_noise_schedule: list + the noise schedules used for fast sampling + + Returns + ------- + waveform: torch.tensor + waveform [1, time] + + audio can be saved by: + >>> from speechbrain.dataio import audio_io + >>> waveform = torch.rand(1, 666666) + >>> sample_rate = 22050 + >>> audio_io.save( + ... str(getfixture("tmpdir") / "test.wav"), waveform, sample_rate + ... ) + """ + with torch.no_grad(): + waveform = self.infer( + unconditional=False, + scale=hop_len, + condition=spectrogram.unsqueeze(0).to(self.device), + fast_sampling=fast_sampling, + fast_sampling_noise_schedule=fast_sampling_noise_schedule, + ) + return waveform.squeeze(0) + + def forward(self, spectrogram): + """Decodes the input spectrograms""" + return self.decode_batch(spectrogram) + + +class UnitHIFIGAN(Pretrained): + """ + A ready-to-use wrapper for Unit HiFiGAN (discrete units -> waveform). + + Arguments + --------- + *args : tuple + See `Pretrained` + **kwargs : dict + See `Pretrained` + + Example + ------- + >>> tmpdir_vocoder = getfixture("tmpdir") / "vocoder" + >>> hifi_gan = UnitHIFIGAN.from_hparams( + ... source="speechbrain/hifigan-hubert-l1-3-7-12-18-23-k1000-LibriTTS", + ... savedir=tmpdir_vocoder, + ... ) + >>> codes = torch.randint(0, 99, (100, 1)) + >>> waveform = hifi_gan.decode_unit(codes) + """ + + HPARAMS_NEEDED = ["generator"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.infer = self.hparams.generator.inference + self.first_call = True + # Temporary fix for mapping indices from the range [0, k] to [1, k+1] + self.tokenize = True + + def decode_batch(self, units, spk=None): + """Computes waveforms from a batch of discrete units + + Arguments + --------- + units: torch.tensor + Batch of discrete units [batch, codes] + spk: torch.tensor + Batch of speaker embeddings [batch, spk_dim] + + Returns + ------- + waveforms: torch.tensor + Batch of mel-waveforms [batch, 1, time] + """ + # Remove weight norm for inference if it's the first call + if self.first_call: + self.hparams.generator.remove_weight_norm() + self.first_call = False + + # Ensure that the units sequence has a length of at least 3 + if units.size(1) < 3: + raise ValueError( + "The 'units' argument should have a length of at least 3 because of padding size." + ) + + # Increment units if tokenization is enabled + if self.tokenize: + units += 1 + if spk is not None: + spk = spk.to(self.device) + with torch.no_grad(): + waveform = self.infer(units.to(self.device), spk=spk) + return waveform + + def decode_unit(self, units, spk=None): + """Computes waveforms from a single sequence of discrete units + Arguments + --------- + units: torch.tensor + codes: [time] + spk: torch.tensor + spk: [spk_dim] + Returns + ------- + waveform: torch.tensor + waveform [1, time] + """ + # Remove weight norm for inference if it's the first call + if self.first_call: + self.hparams.generator.remove_weight_norm() + self.first_call = False + + # Ensure that the units sequence has a length of at least 4 + if units.size(0) < 4: + raise ValueError( + "The 'units' argument should have a length of at least 4 because of padding size." + ) + + # Increment units if tokenization is enabled + if self.tokenize: + units = units + 1 + if spk is not None: + spk = spk.unsqueeze(0).to(self.device) + with torch.no_grad(): + waveform = self.infer(units.unsqueeze(0).to(self.device), spk=spk) + return waveform.squeeze(0) + + def forward(self, units, spk=None): + "Decodes the input units" + return self.decode_batch(units, spk=spk) diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/README.md b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/README.md new file mode 100644 index 000000000..d4f69cabb --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/README.md @@ -0,0 +1,33 @@ +Third-Party Integrations +------------------------ + +This python module serves to collect all the (non-recipe) SpeechBrain code that relies on +external libraries not present in the explicit dependency list in `pyproject.toml` (and `requirements.txt`). +By keeping the dependency list as small as possible we keep SpeechBrain lightweight and easy to maintain. +In addition, this folder makes it easier to keep track of what third-party tools have been +added and apply different rules to the adding and maintenance of new external integrations. + +> [!WARNING] +> Since these third-party integrations rely on libraries not part of the core toolkit, we make +> no guarantees as to the proper functioning of these libraries; they may be +> broken on the develop branch at any time. We will check that they function correctly +> only when creating a new release of the toolkit. + +In order to minimize the impact of libraries changing and causing the integrations +to stop functioning, we will add additional tests and checks on code in this module. +If the tests are broken, we may remove rather than fix the code in this integration +depending on our capacity. + +To add new code to the module, please ensure it contains runnable examples in the docstring +and tests in the `integrations/tests` folder. You can check that all the tests pass by running + +```bash +$ sh tests/.third-party-tests.sh +``` + +In addition we would like new modules to have 80% or greater coverage of the code, evaluated +using the following code, with `pytest-cov` installed: + +```bash +$ pytest --cov=speechbrain/integrations --cov-context=test --doctest-modules speechbrain/integrations +``` diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/__init__.py new file mode 100644 index 000000000..179ceec69 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/__init__.py @@ -0,0 +1,7 @@ +""" +Package for code with additional dependencies. + +Any code with dependencies beyond those explicitly listed in the `pyproject.toml` or `requirements.txt` file +is typically added in a sub-module within this `integrations` module with a `README.md` explaining the +dependency. +""" diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/README.md b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/README.md new file mode 100644 index 000000000..9daa94512 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/README.md @@ -0,0 +1,31 @@ +Alignment +--------- + +This folder contains code for doing speech alignment using the [CTC Segmentation library](https://github.com/lumaku/ctc-segmentation) + +Here is a record of test setup and relevant results: + +```bash +$ pip install ctc-segmentation==1.7.4 numpy<2.0 +$ pytest --cov=speechbrain/integrations/alignment/ --cov-context=test --doctest-modules speechbrain/integrations/alignment/ + +=================== test session starts ======================= +platform linux -- Python 3.11.11, pytest-7.4.0, pluggy-1.5.0 +configfile: pytest.ini +plugins: anyio-4.8.0, hydra-core-1.3.2, cov-6.1.1, typeguard-4.4.1 +collected 9 items + +speechbrain/integrations/alignment/ctc_seg.py . +speechbrain/integrations/alignment/diarization.py ........ + +============================ tests coverage =========================== +__________ coverage: platform linux, python 3.11.11-final-0 ___________ + +Name Stmts Miss Cover +----------------------------------------------------------------------- +speechbrain/integrations/alignment/ctc_seg.py 191 54 72% +speechbrain/integrations/alignment/diarization.py 317 133 58% +----------------------------------------------------------------------- +TOTAL 508 187 63% + +``` diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/__init__.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/__init__.py new file mode 100644 index 000000000..42695e7b5 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/__init__.py @@ -0,0 +1,3 @@ +""" +Package for speech alignment using the CTC Segmentation library. +""" diff --git a/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/ctc_seg.py b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/ctc_seg.py new file mode 100644 index 000000000..2c16ff9d2 --- /dev/null +++ b/runtime/ops/mapper/audio_asr_pipeline/audio_preprocessor/local_libs/speechbrain/speechbrain/integrations/alignment/ctc_seg.py @@ -0,0 +1,675 @@ +#!/usr/bin/env python3 +"""Perform CTC segmentation to align utterances within audio files. + +This uses the ctc-segmentation Python package. +Install it with pip or see the installing instructions in +https://github.com/lumaku/ctc-segmentation + +Authors + * Ludwig Kürzinger 2021 +""" + +from pathlib import Path +from types import SimpleNamespace +from typing import List, Optional, Union + +import numpy as np +import torch + +# speechbrain interface +from speechbrain.inference.ASR import EncoderASR, EncoderDecoderASR +from speechbrain.utils.logger import get_logger + +# imports for CTC segmentation +try: + from ctc_segmentation import ( + CtcSegmentationParameters, + ctc_segmentation, + determine_utterance_segments, + prepare_text, + prepare_token_list, + ) +except ImportError: + print( + "ImportError: " + "Is the ctc_segmentation module installed " + "and in your PYTHONPATH?" + ) + raise ImportError("The ctc_segmentation module is missing.") + +logger = get_logger(__name__) + + +class CTCSegmentationTask(SimpleNamespace): + """Task object for CTC segmentation. + + This object is automatically generated and acts as + a container for results of a CTCSegmentation object. + + When formatted with str(·), this object returns + results in a kaldi-style segments file formatting. + The human-readable output can be configured with + the printing options. + + Attributes + ---------- + text : list + Utterance texts, separated by line. But without the utterance + name at the beginning of the line (as in kaldi-style text). + ground_truth_mat : array + Ground truth matrix (CTC segmentation). + utt_begin_indices : np.ndarray + Utterance separator for the Ground truth matrix. + timings : np.ndarray + Time marks of the corresponding chars. + state_list : list + Estimated alignment of chars/tokens. + segments : list + Calculated segments as: (start, end, confidence score). + config : CtcSegmentationParameters + CTC Segmentation configuration object. + name : str + Name of aligned audio file (Optional). If given, name is + considered when generating the text. + Default: "utt". + utt_ids : list + The list of utterance names (Optional). This list should + have the same length as the number of utterances. + lpz : np.ndarray + CTC posterior log probabilities (Optional). + print_confidence_score : bool + Include the confidence score. + Default: True. + print_utterance_text : bool + Include utterance text. + Default: True. + + """ + + text = None + ground_truth_mat = None + utt_begin_indices = None + timings = None + char_probs = None + state_list = None + segments = None + config = None + done = False + # Optional + name = "utt" + utt_ids = None + lpz = None + # Printing + print_confidence_score = True + print_utterance_text = True + + def set(self, **kwargs): + """Update object attributes.""" + self.__dict__.update(kwargs) + + def __str__(self): + """Return a kaldi-style ``segments`` file (string).""" + output = "" + num_utts = len(self.segments) + if self.utt_ids is None: + utt_names = [f"{self.name}_{i:04}" for i in range(num_utts)] + else: + # ensure correct mapping of segments to utterance ids + assert num_utts == len(self.utt_ids) + utt_names = self.utt_ids + for i, boundary in enumerate(self.segments): + # utterance name and file name + utt_entry = f"{utt_names[i]} {self.name} " + # segment start and end + utt_entry += f"{boundary[0]:.2f} {boundary[1]:.2f}" + # confidence score + if self.print_confidence_score: + utt_entry += f" {boundary[2]:3.4f}" + # utterance ground truth + if self.print_utterance_text: + utt_entry += f" {self.text[i]}" + output += utt_entry + "\n" + return output + + +class CTCSegmentation: + """Align text to audio using CTC segmentation. + + Usage: Initialize with given ASR model and parameters. + If needed, parameters for CTC segmentation can be set with ``set_config(·)``. + Then call the instance as function to align text within an audio file. + + Arguments + --------- + asr_model : EncoderDecoderASR + Speechbrain ASR interface. This requires a model that has a + trained CTC layer for inference. It is better to use a model with + single-character tokens to get a better time resolution. + Please note that the inference complexity with Transformer models + usually increases quadratically with audio length. + It is therefore recommended to use RNN-based models, if available. + kaldi_style_text : bool + A kaldi-style text file includes the name of the + utterance at the start of the line. If True, the utterance name + is expected as first word at each line. If False, utterance + names are automatically generated. Set this option according to + your input data. Default: True. + text_converter : str + How CTC segmentation handles text. + "tokenize": Use the ASR model tokenizer to tokenize the text. + "classic": The text is preprocessed as text pieces which takes + token length into account. If the ASR model has longer tokens, + this option may yield better results. Default: "tokenize". + time_stamps : str + Choose the method how the time stamps are + calculated. While "fixed" and "auto" use both the sample rate, + the ratio of samples to one frame is either automatically + determined for each inference or fixed at a certain ratio that + is initially determined by the module, but can be changed via + the parameter ``samples_to_frames_ratio``. Recommended for + longer audio files: "auto". + **ctc_segmentation_args + Parameters for CTC segmentation. + The full list of parameters is found in ``set_config``. + + Example + ------- + >>> # using example file included in the SpeechBrain repository + >>> from speechbrain.inference.ASR import EncoderDecoderASR + >>> # load an ASR model + >>> pre_trained = "speechbrain/asr-transformer-transformerlm-librispeech" + >>> asr_model = EncoderDecoderASR.from_hparams(source=pre_trained) + >>> aligner = CTCSegmentation(asr_model, kaldi_style_text=False) + >>> # load data + >>> audio_path = "tests/samples/single-mic/example1.wav" + >>> text = ["THE BIRCH CANOE", "SLID ON THE", "SMOOTH PLANKS"] + >>> segments = aligner(audio_path, text, name="example1") + + On multiprocessing + ------------------ + To parallelize the computation with multiprocessing, these three steps + can be separated: + (1) ``get_lpz``: obtain the lpz, + (2) ``prepare_segmentation_task``: prepare the task, and + (3) ``get_segments``: perform CTC segmentation. + Note that the function `get_segments` is a static method and therefore + independent of an already initialized CTCSegmentation object. + + References + ---------- + CTC-Segmentation of Large Corpora for German End-to-end Speech Recognition + 2020, Kürzinger, Winkelbauer, Li, Watzel, Rigoll + https://arxiv.org/abs/2007.09127 + + More parameters are described in https://github.com/lumaku/ctc-segmentation + """ + + fs = 16000 + kaldi_style_text = True + samples_to_frames_ratio = None + time_stamps = "auto" + choices_time_stamps = ["auto", "fixed"] + text_converter = "tokenize" + choices_text_converter = ["tokenize", "classic"] + warned_about_misconfiguration = False + config = CtcSegmentationParameters() + + def __init__( + self, + asr_model: Union[EncoderASR, EncoderDecoderASR], + kaldi_style_text: bool = True, + text_converter: str = "tokenize", + time_stamps: str = "auto", + **ctc_segmentation_args, + ): + # Prepare ASR model + if ( + isinstance(asr_model, EncoderDecoderASR) + and not ( + hasattr(asr_model, "mods") + and hasattr(asr_model.mods, "decoder") + and hasattr(asr_model.mods.decoder, "ctc_weight") + ) + ) or ( + isinstance(asr_model, EncoderASR) + and not ( + hasattr(asr_model, "mods") + and hasattr(asr_model.mods, "encoder") + and hasattr(asr_model.mods.encoder, "ctc_lin") + ) + ): + raise AttributeError("The given asr_model has no CTC module!") + if not hasattr(asr_model, "tokenizer"): + raise AttributeError( + "The given asr_model has no tokenizer in asr_model.tokenizer!" + ) + self.asr_model = asr_model + self._encode = self.asr_model.encode_batch + + if isinstance(asr_model, EncoderDecoderASR): + if not hasattr(self.asr_model.hparams, "scorer"): + raise AttributeError( + "``ScorerBuilder`` module is required for CTC segmentation." + ) + + if "ctc" not in self.asr_model.hparams.scorer.full_scorers: + raise AttributeError( + "``CTCScorer`` module is required for CTC segmentation." + ) + + def ctc_forward_step(x: torch.Tensor) -> torch.Tensor: + """Forward step for CTC module.""" + module = self.asr_model.hparams.scorer.full_scorers["ctc"] + logits = module.ctc_fc(x) + log_probs = module.softmax(logits) + return log_probs + + self._ctc = ctc_forward_step + else: + # Apply log-softmax to encoder output + self._ctc = self.asr_model.hparams.log_softmax + self._tokenizer = self.asr_model.tokenizer + + # Apply configuration + self.set_config( + fs=self.asr_model.hparams.sample_rate, + time_stamps=time_stamps, + kaldi_style_text=kaldi_style_text, + text_converter=text_converter, + **ctc_segmentation_args, + ) + + # determine token or character list + char_list = [ + asr_model.tokenizer.id_to_piece(i) + for i in range(asr_model.tokenizer.vocab_size()) + ] + self.config.char_list = char_list + + # Warn about possible misconfigurations + max_char_len = max([len(c) for c in char_list]) + if len(char_list) > 500 and max_char_len >= 8: + logger.warning( + f"The dictionary has {len(char_list)} tokens with " + f"a max length of {max_char_len}. This may lead " + f"to low alignment performance and low accuracy." + ) + + def set_config( + self, + time_stamps: Optional[str] = None, + fs: Optional[int] = None, + samples_to_frames_ratio: Optional[float] = None, + set_blank: Optional[int] = None, + replace_spaces_with_blanks: Optional[bool] = None, + kaldi_style_text: Optional[bool] = None, + text_converter: Optional[str] = None, + gratis_blank: Optional[bool] = None, + min_window_size: Optional[int] = None, + max_window_size: Optional[int] = None, + scoring_length: Optional[int] = None, + ): + """Set CTC segmentation parameters. + + Parameters for timing + --------------------- + time_stamps : str + Select method how CTC index duration is estimated, and + thus how the time stamps are calculated. + fs : int + Sample rate. Usually derived from ASR model; use this parameter + to overwrite the setting. + samples_to_frames_ratio : float + If you want to directly determine the + ratio of samples to CTC frames, set this parameter, and + set ``time_stamps`` to "fixed". + Note: If you want to calculate the time stamps from a model + with fixed subsampling, set this parameter to: + ``subsampling_factor * frame_duration / 1000``. + + Parameters for text preparation + ------------------------------- + set_blank : int + Index of blank in token list. Default: 0. + replace_spaces_with_blanks : bool + Inserts blanks between words, which is + useful for handling long pauses between words. Only used in + ``text_converter="classic"`` preprocessing mode. Default: False. + kaldi_style_text : bool + Determines whether the utterance name is expected + as fist word of the utterance. Set at module initialization. + text_converter : str + How CTC segmentation handles text. + Set at module initialization. + + Parameters for alignment + ------------------------ + min_window_size : int + Minimum number of frames considered for a single + utterance. The current default value of 8000 corresponds to + roughly 4 minutes (depending on ASR model) and should be OK in + most cases. If your utterances are further apart, increase + this value, or decrease it for smaller audio files. + max_window_size : int + Maximum window size. It should not be necessary + to change this value. + gratis_blank : bool + If True, the transition cost of blank is set to zero. + Useful for long preambles or if there are large unrelated segments + between utterances. Default: False. + + Parameters for calculation of confidence score + ---------------------------------------------- + scoring_length : int + Block length to calculate confidence score. The + default value of 30 should be OK in most cases. + 30 corresponds to roughly 1-2s of audio. + """ + # Parameters for timing + if time_stamps is not None: + if time_stamps not in self.choices_time_stamps: + raise NotImplementedError( + f"Parameter ´time_stamps´ has to be one of " + f"{list(self.choices_time_stamps)}", + ) + self.time_stamps = time_stamps + if fs is not None: + self.fs = float(fs) + if samples_to_frames_ratio is not None: + self.samples_to_frames_ratio = float(samples_to_frames_ratio) + # Parameters for text preparation + if set_blank is not None: + self.config.blank = int(set_blank) + if replace_spaces_with_blanks is not None: + self.config.replace_spaces_with_blanks = bool( + replace_spaces_with_blanks + ) + if kaldi_style_text is not None: + self.kaldi_style_text = bool(kaldi_style_text) + if text_converter is not None: + if text_converter not in self.choices_text_converter: + raise NotImplementedError( + f"Parameter ´text_converter´ has to be one of " + f"{list(self.choices_text_converter)}", + ) + self.text_converter = text_converter + # Parameters for alignment + if min_window_size is not None: + self.config.min_window_size = int(min_window_size) + if max_window_size is not None: + self.config.max_window_size = int(max_window_size) + if gratis_blank is not None: + self.config.blank_transition_cost_zero = bool(gratis_blank) + if ( + self.config.blank_transition_cost_zero + and self.config.replace_spaces_with_blanks + and not self.warned_about_misconfiguration + ): + logger.error( + "Blanks are inserted between words, and also the transition cost of" + " blank is zero. This configuration may lead to misalignments!" + ) + self.warned_about_misconfiguration = True + # Parameter for calculation of confidence score + if scoring_length is not None: + self.config.score_min_mean_over_L = int(scoring_length) + + def get_timing_config(self, speech_len=None, lpz_len=None): + """Obtain parameters to determine time stamps.""" + timing_cfg = { + "index_duration": self.config.index_duration, + } + # As the parameter ctc_index_duration vetoes the other + if self.time_stamps == "fixed": + # Initialize the value, if not yet available + if self.samples_to_frames_ratio is None: + ratio = self.estimate_samples_to_frames_ratio() + self.samples_to_frames_ratio = ratio + index_duration = self.samples_to_frames_ratio / self.fs + else: + assert self.time_stamps == "auto" + samples_to_frames_ratio = speech_len / lpz_len + index_duration = samples_to_frames_ratio / self.fs + timing_cfg["index_duration"] = index_duration + return timing_cfg + + def estimate_samples_to_frames_ratio(self, speech_len=215040): + """Determine the ratio of encoded frames to sample points. + + This method helps to determine the time a single encoded frame occupies. + As the sample rate already gave the number of samples, only the ratio + of samples per encoded CTC frame are needed. This function estimates them by + doing one inference, which is only needed once. + + Arguments + --------- + speech_len : int + Length of randomly generated speech vector for single + inference. Default: 215040. + + Returns + ------- + int + Estimated ratio. + """ + random_input = torch.rand(speech_len) + lpz = self.get_lpz(random_input) + lpz_len = lpz.shape[0] + # CAVEAT assumption: Frontend does not discard trailing data! + samples_to_frames_ratio = speech_len / lpz_len + return samples_to_frames_ratio + + @torch.no_grad() + def get_lpz(self, speech: Union[torch.Tensor, np.ndarray]): + """Obtain CTC posterior log probabilities for given speech data. + + Arguments + --------- + speech : Union[torch.Tensor, np.ndarray] + Speech audio input. + + Returns + ------- + np.ndarray + Numpy vector with CTC log posterior probabilities. + """ + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + # Batch data: (Nsamples,) -> (1, Nsamples) + speech = speech.unsqueeze(0).to(self.asr_model.device) + wav_lens = torch.tensor([1.0]).to(self.asr_model.device) + enc = self._encode(speech, wav_lens) + # Apply ctc layer to obtain log character probabilities + lpz = self._ctc(enc).detach() + # Shape should be (