Inference error with Tacotron2+WaveNet

Hello TTS team, thanks a lot for open sourcing this project!

I’m a newbie to TTS and deep learning, so I’m trying to get the hang of it by going through the whole process of training and inference using the LJSpeech sample dataset. I ran the training script using this notebook: https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b. It went well. I’m using the Tacotron2 model, and I stopped training at the 10K checkpoint. Now I want to go through the inference process of generating a wave file for an input sentence. I know that 10K iterations are not enough, but I just want to iron out the end-to-end process before committing to a full training run (even if what I get is just static or gibberish). So I’m now trying to run inference using this notebook: https://colab.research.google.com/github/tugstugi/dl-colab-notebooks/blob/master/notebooks/Mozilla_TTS_WaveRNN.ipynb

I’m ending up with the following errors:

(this is the cell I’m executing; the last one)
align, spec, stop_tokens, wav = tts(model, SENTENCE, CONFIG, use_cuda, ap, speaker_id=0, use_gl=False, figures=False)

(error below)
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1370         ret = torch.addmm(bias, input, weight.t())
   1371     else:
-> 1372         output = input.matmul(weight.t())
   1373         if bias is not None:
   1374             output += bias

RuntimeError: size mismatch, m1: [1 x 560], m2: [80 x 256] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:290

=====
I’m pretty much using the standard config files and the sample dataset (LJSpeech). What config changes should I make (to TTS or WaveNet), or how else can I get past this error with the 10K-checkpoint Tacotron2 model I have? Any help is appreciated. Thanks!
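
A note on reading the error: `linear` is multiplying a [1 x 560] input by a transposed weight of shape [80 x 256], so the layer expects 80 input features but receives 560, which is exactly 7 x 80. One possible reading, not confirmed anywhere in this thread, is that the checkpoint and the inference code disagree on how many 80-feature frames reach that layer per step. The sketch below is plain Python and only does this arithmetic; `diagnose_linear_mismatch` is a hypothetical helper name, not part of TTS or PyTorch.

```python
def diagnose_linear_mismatch(input_shape, weight_t_shape):
    """Describe why `input @ weight_t` fails, given two 2-D shapes.

    input_shape    -- (batch, features) of the tensor passed to the layer
    weight_t_shape -- (in_features, out_features) of the transposed weight
    Hypothetical helper for illustration only.
    """
    got = input_shape[1]
    expected = weight_t_shape[0]
    if got == expected:
        return "shapes are compatible"
    if got % expected == 0:
        # The input is a whole multiple of what the layer expects,
        # e.g. several frames concatenated into one vector.
        factor = got // expected
        return f"input has {factor} x {expected} features; layer expects {expected}"
    return f"input has {got} features; layer expects {expected}"


# Shapes copied from the RuntimeError above (m1 and m2).
print(diagnose_linear_mismatch((1, 560), (80, 256)))
# prints: input has 7 x 80 features; layer expects 80
```

A whole-number factor like this usually points to a configuration or code-version difference rather than corrupted weights, but that is an inference from the shapes alone.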

It could be any number of things, but I’d start by ensuring that you’re using consistent branches of the TTS repo for both notebooks.

Thank you Neil. That seemed to be the problem. I checked out the right commit of the TTS repo and things started working.
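
For anyone wanting to check this on their own setup, here is a minimal sketch of comparing the commit checked out in two local clones of the repo. It assumes `git` is on the PATH; the directory names in the usage comment are hypothetical, and the thread does not say which commit turned out to be the right one, so none is filled in here.

```python
import subprocess


def current_commit(repo_dir):
    """Return the full hash of the commit checked out in repo_dir."""
    result = subprocess.run(
        ["git", "rev-parse", "HEAD"],
        cwd=repo_dir,
        stdout=subprocess.PIPE,
        universal_newlines=True,  # decode output as text (works on Python 3.6)
        check=True,               # raise if git exits with an error
    )
    return result.stdout.strip()


def same_commit(hash_a, hash_b):
    """True if the hashes match, allowing one to be an abbreviated prefix."""
    a = hash_a.strip().lower()
    b = hash_b.strip().lower()
    if not a or not b:
        return False
    return a.startswith(b) or b.startswith(a)


# Usage with hypothetical paths for the training and inference clones:
# train_hash = current_commit("TTS_train")
# infer_hash = current_commit("TTS_infer")
# print(same_commit(train_hash, infer_hash))
```

If the two hashes differ, checking out the training commit in the inference clone (`git checkout <hash>`) is the simplest way to make them consistent.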

@ttaong6 What was the right commit? Just had exactly the same error.