```shell
pip install auraloss
```
```python
import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss()

input = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)

loss = mrstft(input, target)
```
We categorize the loss functions as either time-domain or frequency-domain approaches. Additionally, we include perceptual transforms.
| Loss function | Interface | Reference |
| --- | --- | --- |
| **Time domain** | | |
| Error-to-signal ratio (ESR) | `auraloss.time.ESRLoss()` | Wright & Välimäki, 2019 |
| DC error (DC) | `auraloss.time.DCLoss()` | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | `auraloss.time.LogCoshLoss()` | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | `auraloss.time.SNRLoss()` | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | `auraloss.time.SISDRLoss()` | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | `auraloss.time.SDSDRLoss()` | Le Roux et al., 2018 |
| **Frequency domain** | | |
| Aggregate STFT | `auraloss.freq.STFTLoss()` | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | `auraloss.freq.MelSTFTLoss(sample_rate)` | |
| Multi-resolution STFT | `auraloss.freq.MultiResolutionSTFTLoss()` | Yamamoto et al., 2019* |
| Random-resolution STFT | `auraloss.freq.RandomResolutionSTFTLoss()` | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | `auraloss.freq.SumAndDifferenceSTFTLoss()` | Steinmetz et al., 2020 |
| **Perceptual transforms** | | |
| Sum and difference signal transform | `auraloss.perceptual.SumAndDifference()` | |
| FIR pre-emphasis filters | `auraloss.perceptual.FIRFilter()` | Wright & Välimäki, 2019 |
\* Wang et al., 2019 also propose a multi-resolution spectral loss (which Engel et al., 2020 follow), but they do not include both the log magnitude (L1 distance) and spectral convergence terms, which were introduced in Arik et al., 2018 and then extended to the multi-resolution case in Yamamoto et al., 2019.
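As a concrete example of what the time-domain losses measure, the error-to-signal ratio (ESR) from Wright & Välimäki, 2019 normalizes the squared error by the target signal's energy. A minimal pure-Python sketch for intuition only; `auraloss.time.ESRLoss()` operates on batched tensors, and details such as the stabilizing epsilon here are assumptions:

```python
def esr(estimate, target, eps=1e-8):
    """Error-to-signal ratio: squared error normalized by target energy.

    eps is a hypothetical stabilizer to avoid division by zero.
    """
    num = sum((t - e) ** 2 for e, t in zip(estimate, target))
    den = sum(t ** 2 for t in target) + eps
    return num / den

target = [1.0, -1.0, 1.0, -1.0]
print(esr(target, target))     # perfect estimate -> 0.0
print(esr([0.0] * 4, target))  # silent estimate -> ~1.0
```

Because the error is scaled by the target's energy, ESR is comparable across signals of different loudness, which is why it is popular for amplifier and effect modeling.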
Currently we include an example using a set of the loss functions to train a TCN for modeling an analog dynamic range compressor.
For details please see examples/compressor.
We provide pre-trained models, evaluation scripts to compute the metrics in the paper, as well as scripts to retrain models.
There are some more advanced things you can do based upon the `STFTLoss` class.
For example, you can compute both linear and log-scaled STFT errors as in Engel et al., 2020.
In this case we do not include the spectral convergence term.
```python
stft_loss = auraloss.freq.STFTLoss(
    w_log_mag=1.0,
    w_lin_mag=1.0,
    w_sc=0.0,
)
```
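For reference, the two spectral terms being weighted against each other are the spectral convergence (the Frobenius distance between magnitude spectrograms, normalized by the target's magnitude) and the L1 distance between log magnitudes. A hedged pure-Python sketch over flattened magnitude spectrograms; the library's exact reduction and epsilon handling are assumptions here:

```python
import math

def spectral_convergence(est_mag, tgt_mag, eps=1e-8):
    # Frobenius distance between magnitudes, normalized by target magnitude
    num = math.sqrt(sum((t - e) ** 2 for e, t in zip(est_mag, tgt_mag)))
    den = math.sqrt(sum(t ** 2 for t in tgt_mag)) + eps
    return num / den

def log_magnitude_l1(est_mag, tgt_mag, eps=1e-8):
    # mean L1 distance between log magnitudes
    return sum(abs(math.log(t + eps) - math.log(e + eps))
               for e, t in zip(est_mag, tgt_mag)) / len(tgt_mag)

mags = [0.5, 1.0, 2.0]
print(spectral_convergence(mags, mags))  # identical spectra -> 0.0
print(log_magnitude_l1(mags, mags))      # identical spectra -> 0.0
```

Roughly speaking, the spectral convergence term emphasizes large spectral peaks while the log-magnitude term emphasizes quiet, fine-grained spectral detail, which is why combining (or selectively disabling) them is useful.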
There is also a Mel-scaled STFT loss, which has some special requirements. This loss requires that you set the sample rate as well as specify the correct device.
```python
sample_rate = 44100
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
```
You can also build a multi-resolution Mel-scaled STFT loss with 64 bins easily. Make sure you pass the correct device where the tensors you are comparing will reside.
```python
mrmelstft_loss = auraloss.freq.MultiResolutionSTFTLoss(
    scale="mel",
    n_bins=64,
    sample_rate=sample_rate,
    device="cuda",
)
```
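Conceptually, a multi-resolution loss just evaluates a single-resolution STFT loss at several analysis resolutions and aggregates the results, trading off time and frequency resolution. A sketch of that aggregation; the FFT sizes are illustrative, and the plain averaging is an assumption (auraloss's defaults and weighting may differ):

```python
def multi_resolution_loss(est, tgt, single_res_loss, fft_sizes=(512, 1024, 2048)):
    # evaluate a per-resolution loss at each FFT size and average;
    # averaging is an assumption, the library may sum or weight differently
    losses = [single_res_loss(est, tgt, n_fft) for n_fft in fft_sizes]
    return sum(losses) / len(losses)

# usage with a stand-in per-resolution loss that just returns the FFT size
print(multi_resolution_loss(None, None, lambda e, t, n: n))
```

The benefit is that no single window size has to capture both transients (small windows) and fine pitch detail (large windows).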
We currently have no tests, but those will also be coming soon, so use caution for now. Future loss functions to be included will target neural-network-based perceptual losses, which tend to be a bit more sophisticated than those we have included so far.
If you are interested in adding a loss function please make a pull request.
If you use this code in your work please consider citing us.
```bibtex
@inproceedings{steinmetz2020auraloss,
    title={auraloss: {A}udio focused loss functions in {PyTorch}},
    author={Steinmetz, Christian J. and Reiss, Joshua D.},
    booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
    year={2020}
}
```